Commit 6ad6c8c9 (unverified), authored Aug 06, 2025 by eigen; committed by GitHub, Aug 06, 2025.

feat: openai oss attention sink support with trtllm-gen backend #8825 (#8834)

Co-authored-by: averyhuang <averyh@nvidia.com>

Parent: 5b6acc14
Showing 4 changed files with 37 additions and 22 deletions (+37 / -22):

python/sglang/srt/layers/attention/trtllm_mha_backend.py  +19 / -15
python/sglang/srt/model_executor/model_runner.py            +3 / -3
python/sglang/srt/models/gpt_oss.py                         +1 / -1
python/sglang/srt/server_args.py                           +14 / -3
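The core of this commit is threading the GPT-OSS attention "sinks" down to the flashinfer trtllm-gen MHA kernel; as the comment added in the backend puts it, a sink is an additional per-head value in the denominator of the softmax. A minimal illustrative sketch of that idea, not the flashinfer kernel itself (the helper name and the [num_heads] sink shape are assumptions):

    import torch

    def softmax_with_sinks(scores: torch.Tensor, sinks: torch.Tensor) -> torch.Tensor:
        # scores: [num_heads, q_len, kv_len] raw attention logits
        # sinks:  [num_heads] learned per-head sink logits (assumed shape)
        num_heads, q_len, _ = scores.shape
        sink_col = sinks.view(num_heads, 1, 1).expand(num_heads, q_len, 1)
        logits = torch.cat([scores, sink_col], dim=-1)  # append one sink logit per row
        probs = torch.softmax(logits.float(), dim=-1)   # the sink only enlarges the denominator
        return probs[..., :-1]                          # weights over real keys now sum to < 1

    weights = softmax_with_sinks(torch.randn(4, 2, 8), torch.zeros(4))
    print(weights.sum(dim=-1))  # every row sums to less than 1.0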
python/sglang/srt/layers/attention/trtllm_mha_backend.py
 from __future__ import annotations
 
 """
-Support attention backend for TRTLLM MLA kernels from flashinfer.
+Support attention backend for TRTLLM MHA kernels from flashinfer.
+The kernel supports sm100 only, with sliding window and attention sink features.
 """
 
 from dataclasses import dataclass
@@ -57,11 +58,6 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # MHA-specific dimensions
         self.max_context_len = model_runner.model_config.context_len
-        self.sliding_window_size = (
-            model_runner.sliding_window_size
-            if model_runner.sliding_window_size is not None
-            else -1  # -1 indicates full attention
-        )
         self.hidden_size = config.hidden_size
 
         # Runtime parameters
@@ -117,10 +113,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         metadata = TRTLLMMHAMetadata()
 
         # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+        metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
 
         # Precompute maximum sequence length
-        metadata.max_seq_len_k = seq_lens.max().item()
+        metadata.max_seq_len_k = self.max_context_len
 
         # Precompute page table
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
@@ -149,7 +145,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         metadata = self.decode_cuda_graph_metadata[bs]
         max_len = seq_lens_cpu.max().item()
         max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-        metadata.max_seq_len_k = max_len
+        metadata.max_seq_len_k = self.max_context_len
 
         metadata.cache_seqlens_int32.copy_(seq_lens)
         page_indices = self.req_to_token[
@@ -217,6 +213,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MHA kernel."""
         cache_loc = forward_batch.out_cache_loc
@@ -228,7 +225,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
 
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         # shape conversion:
-        # [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
         ).permute(0, 2, 1, 3)
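The corrected comment reflects that the KV buffer is paged, so its leading dimension counts pages rather than batch entries, and the kernel call as wired here expects a [num_pages, num_kv_heads, page_size, head_dim] layout. A toy reshape with made-up sizes showing the same view + permute:

    import torch

    num_pages, page_size, num_kv_heads, head_dim = 8, 16, 4, 64  # toy sizes
    flat_kv = torch.zeros(num_pages * page_size, num_kv_heads, head_dim)  # token-major pool slice

    paged = flat_kv.view(-1, page_size, num_kv_heads, head_dim)  # [num_pages, page_size, H_kv, D]
    paged = paged.permute(0, 2, 1, 3)                            # [num_pages, H_kv, page_size, D]
    assert paged.shape == (num_pages, num_kv_heads, page_size, head_dim)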
@@ -237,7 +234,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         ).permute(0, 2, 1, 3)
         kv_cache = (k_cache, v_cache)
 
-        # TODO: bmm1_scale and bmm2_scale might require modification
+        # TODO: add support for quantization
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
@@ -246,6 +243,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )
         bmm1_scale = q_scale * k_scale * layer.scaling
         bmm2_scale = 1.0
 
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
         # Call TRT-LLM kernel
         # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
@@ -258,8 +257,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             max_seq_len=self.forward_metadata.max_seq_len_k,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -272,6 +272,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        **kwargs,
     ):
         cache_loc = forward_batch.out_cache_loc
         if save_kv_cache and k is not None:
@@ -279,6 +280,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
 
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
@@ -288,8 +290,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         ).permute(0, 2, 1, 3)
         kv_cache = (k_cache, v_cache)
 
-        # TODO: bmm1_scale and bmm2_scale might require modification
-        # TODO: Change once quantization is supported
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
+        # TODO: add support for quantization
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
@@ -312,8 +315,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             batch_size=forward_batch.batch_size,
             cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
             cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
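Both forward_decode and forward_extend now accept **kwargs, which is how a layer-specific extra such as sinks reaches this backend without touching other backends' signatures. A self-contained sketch of that plumbing (TinyBackend and TinyAttn are invented stand-ins, not sglang classes):

    import torch

    class TinyBackend:
        def forward_decode(self, q, save_kv_cache=True, **kwargs):
            # mirrors the pattern in this diff: optional extras arrive via kwargs
            attention_sink = kwargs.get("sinks", None)
            return q if attention_sink is None else q + 0.0 * attention_sink.sum()

    class TinyAttn(torch.nn.Module):
        def __init__(self, backend, num_heads=4):
            super().__init__()
            self.backend = backend
            self.sinks = torch.nn.Parameter(torch.zeros(num_heads))  # assumed per-head parameter

        def forward(self, q):
            # cast before handing the sinks to the backend, as gpt_oss.py does below
            return self.backend.forward_decode(q, sinks=self.sinks.to(torch.float32))

    out = TinyAttn(TinyBackend())(torch.randn(2, 4, 64))
    print(out.shape)  # torch.Size([2, 4, 64])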
python/sglang/srt/model_executor/model_runner.py
@@ -1443,13 +1443,13 @@ class ModelRunner:
             )
             return CutlassMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mla":
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
 
             return TRTLLMMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mha":
+        elif backend_str == "trtllm_mha":
             if self.use_mla_backend:
                 raise ValueError(
                     "trtllm_mha backend can only be used with non-MLA models."
@@ -1460,7 +1460,7 @@ class ModelRunner:
             return TRTLLMHAAttnBackend(self)
-        elif self.server_args.attention_backend == "intel_amx":
+        elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )
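These branches now compare the local backend_str instead of self.server_args.attention_backend, presumably so that prefill- and decode-specific backend overrides resolve through the same dispatch. A toy version of that string-keyed dispatch (the returned names are placeholders, not real constructors):

    def make_attention_backend(backend_str: str, use_mla_backend: bool) -> str:
        # dispatch on the already-resolved backend string, not the raw server arg
        if backend_str == "trtllm_mla":
            if not use_mla_backend:
                raise ValueError("trtllm_mla backend can only be used with MLA models.")
            return "TRTLLMMLABackend"
        elif backend_str == "trtllm_mha":
            if use_mla_backend:
                raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
            return "TRTLLMHAAttnBackend"
        elif backend_str == "intel_amx":
            return "IntelAMXAttnBackend"
        raise ValueError(f"Unsupported attention backend: {backend_str}")

    print(make_attention_backend("trtllm_mha", use_mla_backend=False))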
python/sglang/srt/models/gpt_oss.py
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks)
+        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
         output, _ = self.o_proj(attn_output)
         return output
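The sinks parameter lives in the model's dtype, so this one-line change casts it to float32 before it reaches the attention backend, presumably because the trtllm-gen path consumes float32 sinks. A minimal illustration of the cast (the [num_heads] shape and bfloat16 storage are assumptions):

    import torch

    num_heads = 8  # assumed
    sinks = torch.nn.Parameter(torch.zeros(num_heads, dtype=torch.bfloat16))  # stored in model dtype
    sinks_for_kernel = sinks.to(torch.float32)  # same cast as self.sinks.to(torch.float32) above
    print(sinks.dtype, "->", sinks_for_kernel.dtype)  # torch.bfloat16 -> torch.float32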
python/sglang/srt/server_args.py
@@ -445,7 +445,11 @@ class ServerArgs:
                 "trtllm_mla backend does not support speculative decoding yet."
             )
 
-        if self.attention_backend == "trtllm_mha":
+        if (
+            self.attention_backend == "trtllm_mha"
+            or self.decode_attention_backend == "trtllm_mha"
+            or self.prefill_attention_backend == "trtllm_mha"
+        ):
             if not is_sm100_supported():
                 raise ValueError(
                     "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
@@ -459,11 +463,18 @@ class ServerArgs:
             if self.speculative_algorithm is not None:
                 raise ValueError(
-                    "trtllm_mla backend does not support speculative decoding yet."
+                    "trtllm_mha backend does not support speculative decoding yet."
                 )
 
         model_arch = self.get_hf_config().architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            self.attention_backend = "triton"
+            if self.attention_backend is None:
+                # default is triton, but we could have trtllm_mha as an option
+                self.attention_backend = "triton"
+            assert (
+                self.attention_backend == "trtllm_mha"
+                or self.attention_backend == "triton"
+            )
 
             # Check if FlashInfer MXFP4 MoE is enabled
             from sglang.srt.utils import get_bool_env_var
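The server_args changes validate trtllm_mha wherever it is requested (main, decode, or prefill backend) and, for GptOssForCausalLM, keep triton as the default while allowing an explicitly requested trtllm_mha. A compact sketch of that resolution step (the function name is invented):

    def resolve_gpt_oss_backend(attention_backend):
        # default is triton, but trtllm_mha is accepted when requested explicitly
        if attention_backend is None:
            attention_backend = "triton"
        assert attention_backend in ("trtllm_mha", "triton")
        return attention_backend

    print(resolve_gpt_oss_backend(None))          # triton
    print(resolve_gpt_oss_backend("trtllm_mha"))  # trtllm_mha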