change / sglang · Commit 6ad6c8c9 (unverified)

feat: openai oss attention sink support with trtllm-gen backend #8825 (#8834)

Authored Aug 06, 2025 by eigen; committed via GitHub on Aug 06, 2025.
Co-authored-by: averyhuang <averyh@nvidia.com>
Parent: 5b6acc14
Showing 4 changed files with 37 additions and 22 deletions (+37 -22).
python/sglang/srt/layers/attention/trtllm_mha_backend.py   +19 -15
python/sglang/srt/model_executor/model_runner.py            +3  -3
python/sglang/srt/models/gpt_oss.py                         +1  -1
python/sglang/srt/server_args.py                           +14  -3
python/sglang/srt/layers/attention/trtllm_mha_backend.py

 from __future__ import annotations

 """
-Support attention backend for TRTLLM MLA kernels from flashinfer.
+Support attention backend for TRTLLM MHA kernels from flashinfer.
+The kernel supports sm100 only, with sliding window and attention sink features.
 """

 from dataclasses import dataclass
@@ -57,11 +58,6 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # MHA-specific dimensions
         self.max_context_len = model_runner.model_config.context_len
-        self.sliding_window_size = (
-            model_runner.sliding_window_size
-            if model_runner.sliding_window_size is not None
-            else -1  # -1 indicates full attention
-        )
         self.hidden_size = config.hidden_size

         # Runtime parameters
@@ -117,10 +113,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             metadata = TRTLLMMHAMetadata()

             # Get sequence information
-            metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+            metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
             # Precompute maximum sequence length
-            metadata.max_seq_len_k = seq_lens.max().item()
+            metadata.max_seq_len_k = self.max_context_len
             # Precompute page table
             metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
@@ -149,7 +145,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             metadata = self.decode_cuda_graph_metadata[bs]
             max_len = seq_lens_cpu.max().item()
             max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-            metadata.max_seq_len_k = max_len
+            metadata.max_seq_len_k = self.max_context_len
             metadata.cache_seqlens_int32.copy_(seq_lens)
             page_indices = self.req_to_token[
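Both CUDA-graph hunks above replace the data-dependent `seq_lens.max().item()` with the static `self.max_context_len`, presumably because a captured graph replays with fixed scalars and cannot re-read a per-batch maximum. A small self-contained sketch of the arithmetic involved, using made-up sizes rather than anything from the commit:

# Hypothetical sizes, only to illustrate the ceiling division and the fixed bound.
page_size = 64
max_context_len = 8192            # static model limit, safe to bake into a CUDA graph
seq_lens_cpu = [130, 257, 64]     # per-request KV lengths in one decode batch

max_len = max(seq_lens_cpu)                              # 257, changes every step
max_seq_pages = (max_len + page_size - 1) // page_size   # ceil(257 / 64) = 5
max_seq_len_k = max_context_len                          # what the patched code pins instead

print(max_len, max_seq_pages, max_seq_len_k)             # 257 5 8192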
@@ -217,6 +213,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MHA kernel."""
         cache_loc = forward_batch.out_cache_loc
@@ -228,7 +225,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         # shape conversion:
-        # [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
         ).permute(0, 2, 1, 3)
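The corrected comment reflects that the first dimension of the cache view is pages, not batch entries. A standalone PyTorch sketch of the same view + permute, with toy sizes that are not taken from the commit:

import torch

# The pool stores pages as [num_pages, page_size, num_kv_heads, head_dim];
# the kernel expects [num_pages, num_kv_heads, page_size, head_dim].
num_pages, page_size, num_kv_heads, head_dim = 3, 16, 2, 8  # toy sizes
flat_cache = torch.randn(num_pages * page_size, num_kv_heads, head_dim)

paged = flat_cache.view(-1, page_size, num_kv_heads, head_dim)
kernel_view = paged.permute(0, 2, 1, 3)  # no copy, only a stride change

assert kernel_view.shape == (num_pages, num_kv_heads, page_size, head_dim)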
@@ -237,7 +234,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         ).permute(0, 2, 1, 3)
         kv_cache = (k_cache, v_cache)

-        # TODO: bmm1_scale and bmm2_scale might require modification
+        # TODO: add support for quantization
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
@@ -246,6 +243,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )
         bmm1_scale = q_scale * k_scale * layer.scaling
         bmm2_scale = 1.0

+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
         # Call TRT-LLM kernel
         # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
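The new `sinks` kwarg carries one learned logit per attention head, and, per the comment above, the kernel adds it to the softmax denominator so a head can place less than 100% of its attention mass on real tokens. A toy sketch of that formulation, written from the comment rather than from the TRT-LLM kernel source:

import torch

def softmax_with_sink(logits: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    """Toy illustration: each head contributes one extra logit to the softmax
    denominator. Shapes: logits [heads, seq], sink [heads]."""
    # Subtract the running max for numerical stability, including the sink logit.
    m = torch.maximum(logits.max(dim=-1).values, sink)
    num = torch.exp(logits - m[:, None])
    denom = num.sum(dim=-1) + torch.exp(sink - m)
    return num / denom[:, None]

probs = softmax_with_sink(torch.randn(4, 10), torch.zeros(4))
print(probs.sum(dim=-1))  # strictly less than 1.0 per head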
@@ -258,8 +257,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             max_seq_len=self.forward_metadata.max_seq_len_k,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
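Passing `window_left=layer.sliding_window_size` instead of a single backend-wide value lets individual layers differ; the code removed in the first hunk documents the convention that -1 means full attention. A toy mask illustrating that convention as described there (illustrative only, not the kernel's masking code):

import torch

def sliding_window_mask(seq_len: int, window_left: int) -> torch.Tensor:
    """window_left == -1: full causal attention; otherwise a query at position i
    may attend to keys in [i - window_left, i]."""
    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    causal = j <= i
    if window_left < 0:
        return causal
    return causal & (j >= i - window_left)

print(sliding_window_mask(5, 2).int())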
@@ -272,6 +272,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        **kwargs,
     ):
         cache_loc = forward_batch.out_cache_loc
         if save_kv_cache and k is not None:
@@ -279,6 +280,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
@@ -288,8 +290,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         ).permute(0, 2, 1, 3)
         kv_cache = (k_cache, v_cache)

-        # TODO: bmm1_scale and bmm2_scale might require modification
-        # TODO: Change once quantization is supported
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
+        # TODO: add support for quantization
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
@@ -312,8 +315,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             batch_size=forward_batch.batch_size,
             cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
             cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
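Taken together, the `**kwargs` additions to `forward_decode` and `forward_extend` let a model pass `sinks=` through its attention layer without every backend changing signature; backends that ignore the kwarg keep working. A deliberately simplified stand-in for that plumbing (names here are illustrative, not the real RadixAttention or backend APIs):

import torch

def toy_backend_forward(q, *_, **kwargs):
    attention_sink = kwargs.get("sinks", None)   # None for models without sinks
    return q, attention_sink

def toy_attention_layer(q, *inner_state, backend=toy_backend_forward, **kwargs):
    # Anything extra the model supplies (e.g. sinks=...) reaches the backend unchanged.
    return backend(q, *inner_state, **kwargs)

out, sink = toy_attention_layer(
    torch.randn(2, 8, 64), sinks=torch.zeros(8, dtype=torch.float32)
)
assert sink is not None and sink.dtype == torch.float32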
python/sglang/srt/model_executor/model_runner.py

@@ -1443,13 +1443,13 @@ class ModelRunner:
                 )
             return CutlassMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mla":
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError(
                     "trtllm_mla backend can only be used with MLA models."
                 )
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
             return TRTLLMMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mha":
+        elif backend_str == "trtllm_mha":
             if self.use_mla_backend:
                 raise ValueError(
                     "trtllm_mha backend can only be used with non-MLA models."
@@ -1460,7 +1460,7 @@ class ModelRunner:
             return TRTLLMHAAttnBackend(self)
-        elif self.server_args.attention_backend == "intel_amx":
+        elif backend_str == "intel_amx":
            from sglang.srt.layers.attention.intel_amx_backend import (
                IntelAMXAttnBackend,
            )
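The dispatch now compares the already-resolved `backend_str` rather than `self.server_args.attention_backend`, presumably so that backends chosen through the prefill- or decode-specific options (see the server_args.py change below) still reach these branches. A hypothetical sketch of the pattern, not the ModelRunner internals:

# Hypothetical helper, only to show why comparing the resolved string matters.
def pick_backend(backend_str: str) -> str:
    # backend_str may have been resolved from attention_backend,
    # decode_attention_backend, or prefill_attention_backend upstream.
    if backend_str == "trtllm_mha":
        return "TRTLLMHAAttnBackend"
    if backend_str == "trtllm_mla":
        return "TRTLLMMLABackend"
    if backend_str == "intel_amx":
        return "IntelAMXAttnBackend"
    return "FlashInferAttnBackend"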
python/sglang/srt/models/gpt_oss.py

@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks)
+        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
         output, _ = self.o_proj(attn_output)
         return output
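The only model-side change: the learned per-head sink parameter is cast to float32 before being handed to the attention backend, presumably because the kernel path added in this commit consumes sinks in float32 while the parameter lives in the model dtype. A toy illustration (names and dtypes are assumptions, not taken from gpt_oss.py beyond the cast itself):

import torch
import torch.nn as nn

num_heads = 8
sinks = nn.Parameter(torch.zeros(num_heads, dtype=torch.bfloat16))  # model-dtype parameter

kernel_sinks = sinks.to(torch.float32)   # mirrors self.sinks.to(torch.float32) above
assert kernel_sinks.dtype == torch.float32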
python/sglang/srt/server_args.py

@@ -445,7 +445,11 @@ class ServerArgs:
                     "trtllm_mla backend does not support speculative decoding yet."
                 )

-        if self.attention_backend == "trtllm_mha":
+        if (
+            self.attention_backend == "trtllm_mha"
+            or self.decode_attention_backend == "trtllm_mha"
+            or self.prefill_attention_backend == "trtllm_mha"
+        ):
             if not is_sm100_supported():
                 raise ValueError(
                     "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
@@ -459,11 +463,18 @@ class ServerArgs:
             if self.speculative_algorithm is not None:
                 raise ValueError(
-                    "trtllm_mla backend does not support speculative decoding yet."
+                    "trtllm_mha backend does not support speculative decoding yet."
                 )

         model_arch = self.get_hf_config().architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            self.attention_backend = "triton"
+            if self.attention_backend is None:
+                # default is triton, but we could have trtllm_mha as an option
+                self.attention_backend = "triton"
+            assert (
+                self.attention_backend == "trtllm_mha"
+                or self.attention_backend == "triton"
+            )
             # Check if FlashInfer MXFP4 MoE is enabled
             from sglang.srt.utils import get_bool_env_var
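For GPT-OSS models the backend is no longer forced to triton: an explicit trtllm_mha choice is now kept, and anything else is rejected. A standalone paraphrase of that rule with a hypothetical helper name (the real logic lives inside ServerArgs and reads the HF config):

def resolve_gpt_oss_attention_backend(requested: str | None) -> str:
    # Hypothetical helper mirroring the check above: default stays triton,
    # trtllm_mha is now also accepted for GptOssForCausalLM.
    backend = "triton" if requested is None else requested
    assert backend in ("triton", "trtllm_mha"), (
        f"gpt-oss supports only triton or trtllm_mha attention, got {backend!r}"
    )
    return backend

assert resolve_gpt_oss_attention_backend(None) == "triton"
assert resolve_gpt_oss_attention_backend("trtllm_mha") == "trtllm_mha"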