Unverified Commit 6ad6c8c9 authored by eigen, committed by GitHub

feat: openai oss attention sink support with trtllm-gen backend #8825 (#8834)


Co-authored-by: averyhuang <averyh@nvidia.com>
parent 5b6acc14
from __future__ import annotations
"""
-Support attention backend for TRTLLM MLA kernels from flashinfer.
+Support attention backend for TRTLLM MHA kernels from flashinfer.
+The kernel supports sm100 only, with sliding window and attention sink features.
"""
from dataclasses import dataclass

@@ -57,11 +58,6 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        # MHA-specific dimensions
        self.max_context_len = model_runner.model_config.context_len
-        self.sliding_window_size = (
-            model_runner.sliding_window_size
-            if model_runner.sliding_window_size is not None
-            else -1  # -1 indicates full attention
-        )
        self.hidden_size = config.hidden_size
        # Runtime parameters
@@ -117,10 +113,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        metadata = TRTLLMMHAMetadata()
        # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+        metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seq_lens.max().item()
+        metadata.max_seq_len_k = self.max_context_len
        # Precompute page table
        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]

@@ -149,7 +145,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        metadata = self.decode_cuda_graph_metadata[bs]
        max_len = seq_lens_cpu.max().item()
        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-        metadata.max_seq_len_k = max_len
+        metadata.max_seq_len_k = self.max_context_len
        metadata.cache_seqlens_int32.copy_(seq_lens)
        page_indices = self.req_to_token[
@@ -217,6 +213,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
+        **kwargs,
    ) -> torch.Tensor:
        """Run forward for decode using TRTLLM MHA kernel."""
        cache_loc = forward_batch.out_cache_loc

@@ -228,7 +225,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        # shape conversion:
-        # [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
        k_cache = k_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)

@@ -237,7 +234,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        ).permute(0, 2, 1, 3)
        kv_cache = (k_cache, v_cache)
-        # TODO: bmm1_scale and bmm2_scale might require modification
+        # TODO: add support for quantization
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float

@@ -246,6 +243,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        )
        bmm1_scale = q_scale * k_scale * layer.scaling
        bmm2_scale = 1.0
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
        # Call TRT-LLM kernel
        # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype

@@ -258,8 +257,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
            max_seq_len=self.forward_metadata.max_seq_len_k,
            bmm1_scale=bmm1_scale,
            bmm2_scale=bmm2_scale,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
            # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
        )
        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -272,6 +272,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
+        **kwargs,
    ):
        cache_loc = forward_batch.out_cache_loc
        if save_kv_cache and k is not None:

@@ -279,6 +280,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )
        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        k_cache = k_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim

@@ -288,8 +290,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        ).permute(0, 2, 1, 3)
        kv_cache = (k_cache, v_cache)
-        # TODO: bmm1_scale and bmm2_scale might require modification
-        # TODO: Change once quantization is supported
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
+        # TODO: add support for quantization
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float

@@ -312,8 +315,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
            batch_size=forward_batch.batch_size,
            cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
            cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
            # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
        )
        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
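
For context (not part of the commit): the `sinks` tensor wired through above holds one extra logit per attention head that only enlarges the softmax denominator and contributes no value vector. A minimal reference sketch of that computation, with illustrative names and no claim to match the TRT-LLM kernel internals:

import torch

def attention_with_sinks(q, k, v, sinks, scale):
    # q: [num_heads, q_len, head_dim]; k, v: [num_heads, kv_len, head_dim]
    # sinks: [num_heads], one sink logit per head (float32 in this backend)
    logits = torch.einsum("hqd,hkd->hqk", q, k) * scale
    # Append the sink logit as one extra "column" per head and query position.
    sink_col = sinks.to(logits.dtype).view(-1, 1, 1).expand(-1, logits.shape[1], 1)
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
    # Drop the sink column: it absorbs probability mass but has no value vector.
    return torch.einsum("hqk,hkd->hqd", probs[..., :-1], v)

This is exactly what the in-code comment means by "additional value per head in the denominator of the softmax".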
@@ -1443,13 +1443,13 @@ class ModelRunner:
            )
            return CutlassMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mla":
+        elif backend_str == "trtllm_mla":
            if not self.use_mla_backend:
                raise ValueError("trtllm_mla backend can only be used with MLA models.")
            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
            return TRTLLMMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mha":
+        elif backend_str == "trtllm_mha":
            if self.use_mla_backend:
                raise ValueError(
                    "trtllm_mha backend can only be used with non-MLA models."

@@ -1460,7 +1460,7 @@ class ModelRunner:
            return TRTLLMHAAttnBackend(self)
-        elif self.server_args.attention_backend == "intel_amx":
+        elif backend_str == "intel_amx":
            from sglang.srt.layers.attention.intel_amx_backend import (
                IntelAMXAttnBackend,
            )
...

@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
        hidden_states, forward_batch, inner_state = intermediate_state
        if inner_state is None:
            return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks)
+        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
        output, _ = self.o_proj(attn_output)
        return output
...
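
The cast in GptOssAttention suggests `self.sinks` is kept as a learnable per-head parameter in the model's compute dtype and upcast to float32 only at the attention call, since the kernel path above consumes float32 sink logits. A hypothetical sketch of such a parameter (illustrative names, not taken from this diff):

import torch
from torch import nn

class PerHeadSinks(nn.Module):
    """Hypothetical sketch: one learnable sink logit per attention head."""

    def __init__(self, num_heads: int, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        # Stored in the model's compute dtype alongside the other weights.
        self.sinks = nn.Parameter(torch.zeros(num_heads, dtype=dtype))

    def for_kernel(self) -> torch.Tensor:
        # Mirrors the diff: upcast at call time, since the kernel expects float32.
        return self.sinks.to(torch.float32)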
@@ -445,7 +445,11 @@ class ServerArgs:
                    "trtllm_mla backend does not support speculative decoding yet."
                )
-        if self.attention_backend == "trtllm_mha":
+        if (
+            self.attention_backend == "trtllm_mha"
+            or self.decode_attention_backend == "trtllm_mha"
+            or self.prefill_attention_backend == "trtllm_mha"
+        ):
            if not is_sm100_supported():
                raise ValueError(
                    "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
@@ -459,11 +463,18 @@ class ServerArgs:
            if self.speculative_algorithm is not None:
                raise ValueError(
-                    "trtllm_mla backend does not support speculative decoding yet."
+                    "trtllm_mha backend does not support speculative decoding yet."
                )
        model_arch = self.get_hf_config().architectures[0]
        if model_arch in ["GptOssForCausalLM"]:
-            self.attention_backend = "triton"
+            if self.attention_backend is None:
+                # default is triton, but we could have trtllm_mha as an option
+                self.attention_backend = "triton"
+            assert (
+                self.attention_backend == "trtllm_mha"
+                or self.attention_backend == "triton"
+            )
            # Check if FlashInfer MXFP4 MoE is enabled
            from sglang.srt.utils import get_bool_env_var
...
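
Net effect of the ServerArgs change: GptOss models still default to the triton attention backend, but an explicitly requested trtllm_mha now passes validation (subject to the SM100 check above). A standalone restatement of that resolution logic, for illustration only:

def resolve_gpt_oss_attention_backend(requested):
    # Mirrors the ServerArgs logic above: default to triton, also allow trtllm_mha.
    backend = "triton" if requested is None else requested
    assert backend in ("triton", "trtllm_mha"), (
        "GptOssForCausalLM currently supports only the triton or trtllm_mha attention backends"
    )
    return backend

In practice the backend would be selected through the existing attention_backend server argument (exposed on the CLI as --attention-backend, assuming the usual flag spelling).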