Unverified Commit 6ad6c8c9 authored by eigen, committed by GitHub

feat: openai oss attention sink support with trtllm-gen backend #8825 (#8834)


Co-authored-by: averyhuang <averyh@nvidia.com>
parent 5b6acc14
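For context on what the diff below wires through: an attention sink is one extra learned logit per head that participates in the softmax denominator but contributes no value, letting a head park probability mass there instead of spreading it over real tokens (this matches the "additional value per head in the denominator of the softmax" comment added below). Here is a minimal PyTorch sketch of that idea for a single decode step; the function name, tensor names, and shapes are illustrative only and this is not the FlashInfer/TRT-LLM kernel interface.

```python
import torch

def decode_attention_with_sink(q, k, v, sinks, scale):
    """Reference single-token decode attention with a per-head sink logit.

    q:     [num_heads, head_dim]           query for the current token
    k, v:  [num_heads, seq_len, head_dim]  cached keys / values
    sinks: [num_heads]                     learned sink logit per head (fp32)
    scale: float                           softmax scale, e.g. 1 / sqrt(head_dim)
    """
    logits = torch.einsum("hd,hsd->hs", q.float(), k.float()) * scale  # [num_heads, seq_len]
    # The sink acts as one extra "virtual" column per head in the softmax.
    logits = torch.cat([logits, sinks.view(-1, 1).float()], dim=-1)    # [num_heads, seq_len + 1]
    probs = torch.softmax(logits, dim=-1)
    # The sink column absorbs probability mass but contributes no value.
    return torch.einsum("hs,hsd->hd", probs[:, :-1], v.float())
```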
from __future__ import annotations
"""
Support attention backend for TRTLLM MLA kernels from flashinfer.
Support attention backend for TRTLLM MHA kernels from flashinfer.
The kernel supports sm100 only, with sliding window and attention sink features.
"""
from dataclasses import dataclass
@@ -57,11 +58,6 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
# MHA-specific dimensions
self.max_context_len = model_runner.model_config.context_len
self.sliding_window_size = (
model_runner.sliding_window_size
if model_runner.sliding_window_size is not None
else -1 # -1 indicates full attention
)
self.hidden_size = config.hidden_size
# Runtime parameters
@@ -117,10 +113,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
metadata = TRTLLMMHAMetadata()
# Get sequence information
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
# Precompute maximum sequence length
metadata.max_seq_len_k = seq_lens.max().item()
metadata.max_seq_len_k = self.max_context_len
# Precompute page table
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
@@ -149,7 +145,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
metadata = self.decode_cuda_graph_metadata[bs]
max_len = seq_lens_cpu.max().item()
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
metadata.max_seq_len_k = max_len
metadata.max_seq_len_k = self.max_context_len
metadata.cache_seqlens_int32.copy_(seq_lens)
page_indices = self.req_to_token[
@@ -217,6 +213,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
) -> torch.Tensor:
"""Run forward for decode using TRTLLM MHA kernel."""
cache_loc = forward_batch.out_cache_loc
@@ -228,7 +225,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
# shape conversion:
# [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
# [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
k_cache = k_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
).permute(0, 2, 1, 3)
@@ -237,7 +234,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
).permute(0, 2, 1, 3)
kv_cache = (k_cache, v_cache)
# TODO: bmm1_scale and bmm2_scale might require modification
# TODO: add support for quantization
q_scale = 1.0
k_scale = (
layer.k_scale_float
@@ -246,6 +243,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
)
bmm1_scale = q_scale * k_scale * layer.scaling
bmm2_scale = 1.0
# sink: additional value per head in the denominator of the softmax.
attention_sink = kwargs.get("sinks", None)
# Call TRT-LLM kernel
# raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
@@ -258,8 +257,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
max_seq_len=self.forward_metadata.max_seq_len_k,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
window_left=self.sliding_window_size,
window_left=layer.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
sinks=attention_sink,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
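A note on the shape-conversion comment corrected above: the KV pool buffer is viewed and permuted from a per-page [page_size, num_kv_heads, head_dim] layout to [num_kv_heads, page_size, head_dim] (often called NHD to HND) without any copy, since permute only swaps strides. A small sketch with hypothetical sizes; the flat pool layout shown is an assumption for illustration, not necessarily the exact token_to_kv_pool layout.

```python
import torch

# Hypothetical sizes, just to make the reshape concrete.
num_pages, page_size, num_kv_heads, head_dim = 4, 64, 8, 128

# Assumed pool layout: one row per cached token, [num_pages * page_size, num_kv_heads, head_dim].
k_buffer = torch.empty(num_pages * page_size, num_kv_heads, head_dim, dtype=torch.bfloat16)

# Same view/permute as in forward_decode / forward_extend above:
# [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
k_paged = k_buffer.view(num_pages, page_size, num_kv_heads, head_dim).permute(0, 2, 1, 3)

assert k_paged.shape == (num_pages, num_kv_heads, page_size, head_dim)
# permute() only changes strides, so no data is moved; the kernel reads the HND layout in place.
assert k_paged.data_ptr() == k_buffer.data_ptr()
```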
@@ -272,6 +272,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
**kwargs,
):
cache_loc = forward_batch.out_cache_loc
if save_kv_cache and k is not None:
@@ -279,6 +280,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
# [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
k_cache = k_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
@@ -288,8 +290,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
).permute(0, 2, 1, 3)
kv_cache = (k_cache, v_cache)
# TODO: bmm1_scale and bmm2_scale might require modification
# TODO: Change once quantization is supported
# sink: additional value per head in the denominator of the softmax.
attention_sink = kwargs.get("sinks", None)
# TODO: add support for quantization
q_scale = 1.0
k_scale = (
layer.k_scale_float
@@ -312,8 +315,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
batch_size=forward_batch.batch_size,
cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
window_left=self.sliding_window_size,
window_left=layer.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
sinks=attention_sink,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
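The extend path above passes the same window_left and sinks arguments as the decode path, together with bmm1_scale = q_scale * k_scale * layer.scaling and bmm2_scale = 1.0. The sketch below is a non-kernel reference of how those two arguments are intended to be read; the function name is illustrative, and the window_left off-by-one convention is assumed to follow flashinfer's.

```python
import torch

def masked_scaled_logits(q, k, q_scale, k_scale, softmax_scale, window_left):
    """How bmm1_scale and window_left are meant to be read (reference only, not the kernel).

    bmm1_scale = q_scale * k_scale * softmax_scale folds any per-tensor dequant scales
    into the softmax scale, so a single multiplier is applied to Q.K^T; with
    unquantized KV both scales are 1.0, as in the diff.  window_left < 0 means full
    attention (the removed self.sliding_window_size default); otherwise the decode
    query sees only the last window_left + 1 cached tokens.
    """
    bmm1_scale = q_scale * k_scale * softmax_scale
    logits = torch.einsum("hd,hsd->hs", q.float(), k.float()) * bmm1_scale
    seq_len = k.shape[1]
    if window_left >= 0 and seq_len > window_left + 1:
        # Mask everything outside the sliding window for this (last-position) query.
        logits[:, : seq_len - (window_left + 1)] = float("-inf")
    return logits
```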
@@ -1443,13 +1443,13 @@ class ModelRunner:
)
return CutlassMLABackend(self)
elif self.server_args.attention_backend == "trtllm_mla":
elif backend_str == "trtllm_mla":
if not self.use_mla_backend:
raise ValueError("trtllm_mla backend can only be used with MLA models.")
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
return TRTLLMMLABackend(self)
elif self.server_args.attention_backend == "trtllm_mha":
elif backend_str == "trtllm_mha":
if self.use_mla_backend:
raise ValueError(
"trtllm_mha backend can only be used with non-MLA models."
@@ -1460,7 +1460,7 @@ class ModelRunner:
return TRTLLMHAAttnBackend(self)
elif self.server_args.attention_backend == "intel_amx":
elif backend_str == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,
)
......
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(*inner_state, sinks=self.sinks)
attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
output, _ = self.o_proj(attn_output)
return output
......
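On the GptOssAttention change above: the per-head sinks tensor is cast to torch.float32 before being handed to the attention call, presumably because the backend consumes fp32 sink logits while the parameter itself lives in the model's compute dtype. A hypothetical minimal wiring is sketched below; SimpleSinkAttention and its constructor arguments are made up for illustration, and only the sinks=... keyword and the kwargs.get("sinks", None) lookup come from the diff.

```python
import torch
from torch import nn

class SimpleSinkAttention(nn.Module):
    """Hypothetical module showing how per-head sinks reach a backend via **kwargs."""

    def __init__(self, num_heads: int, backend):
        super().__init__()
        # One learned sink logit per attention head, stored in the model's compute dtype.
        self.sinks = nn.Parameter(torch.zeros(num_heads))
        self.backend = backend

    def forward(self, q, k, v):
        # The backend's forward_decode / forward_extend read kwargs.get("sinks", None),
        # so omitting the keyword keeps older (sink-free) models working unchanged.
        return self.backend(q, k, v, sinks=self.sinks.to(torch.float32))
```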
@@ -445,7 +445,11 @@ class ServerArgs:
"trtllm_mla backend does not support speculative decoding yet."
)
if self.attention_backend == "trtllm_mha":
if (
self.attention_backend == "trtllm_mha"
or self.decode_attention_backend == "trtllm_mha"
or self.prefill_attention_backend == "trtllm_mha"
):
if not is_sm100_supported():
raise ValueError(
"TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
@@ -459,11 +463,18 @@ class ServerArgs:
if self.speculative_algorithm is not None:
raise ValueError(
"trtllm_mla backend does not support speculative decoding yet."
"trtllm_mha backend does not support speculative decoding yet."
)
model_arch = self.get_hf_config().architectures[0]
if model_arch in ["GptOssForCausalLM"]:
self.attention_backend = "triton"
if self.attention_backend is None:
# default is triton, but we could have trtllm_mha as an option
self.attention_backend = "triton"
assert (
self.attention_backend == "trtllm_mha"
or self.attention_backend == "triton"
)
# Check if FlashInfer MXFP4 MoE is enabled
from sglang.srt.utils import get_bool_env_var
......
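Reading the ServerArgs and ModelRunner changes together: gpt-oss models keep triton as the default attention backend but may now opt into trtllm_mha, which remains restricted to SM100 GPUs and to non-speculative decoding. A condensed sketch of that selection rule follows; the helper name and signature are illustrative, not the real ServerArgs method.

```python
from typing import Optional

def resolve_gpt_oss_attention_backend(requested: Optional[str], sm100_supported: bool) -> str:
    """Condensed restatement of the selection rules in the diff above (illustrative only)."""
    if requested is None:
        # Default stays triton; trtllm_mha is now an allowed opt-in for gpt-oss.
        requested = "triton"
    assert requested in ("triton", "trtllm_mha")
    if requested == "trtllm_mha" and not sm100_supported:
        raise ValueError(
            "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). "
            "Please use a different backend."
        )
    return requested
```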