change / sglang · Commits

Commit 40e3b2be (unverified)
feat: add trtllm-gen mha from direct call (#8782)

Authored Aug 05, 2025 by eigen; committed by GitHub on Aug 05, 2025.
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Parent: 75df31b6
Showing 4 changed files with 351 additions and 0 deletions:

    python/sglang/srt/layers/attention/trtllm_mha_backend.py  +321  -0
    python/sglang/srt/managers/schedule_batch.py                 +1  -0
    python/sglang/srt/model_executor/model_runner.py            +11  -0
    python/sglang/srt/server_args.py                            +18  -0
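The server_args.py change below registers "trtllm_mha" as a selectable attention backend, so the new kernel path can be exercised end to end. A minimal sketch of doing so through the offline engine, assuming sglang.Engine forwards these keyword arguments to ServerArgs and using a placeholder model path:

    import sglang as sgl

    # Sketch only: assumes a Blackwell (SM100) GPU and a FlashInfer build that
    # ships the TRT-LLM-gen kernels; the model path is a placeholder.
    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        attention_backend="trtllm_mha",  # choice added by this commit
        page_size=64,                    # TRTLLM MHA accepts page_size 16, 32, or 64
    )
    print(llm.generate("The capital of France is", {"max_new_tokens": 8}))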
python/sglang/srt/layers/attention/trtllm_mha_backend.py (new file, 0 → 100644)
"""
Support attention backend for TRTLLM MHA kernels from flashinfer.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
    import flashinfer

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

# Constants
DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
@dataclass
class TRTLLMMHAMetadata:
    # Sequence lengths for the forward batch
    cache_seqlens_int32: torch.Tensor = None
    # Maximum sequence length for query
    max_seq_len_q: int = 1
    # Maximum sequence length for key
    max_seq_len_k: int = 0
    # Cumulative sequence lengths for query
    cu_seqlens_q: torch.Tensor = None
    # Cumulative sequence lengths for key
    cu_seqlens_k: torch.Tensor = None
    # Page table, the index of KV Cache Tables/Blocks
    page_table: torch.Tensor = None
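For orientation, a toy illustration of what these fields might hold for a prefill batch of two requests of lengths 3 and 5 with page_size 64; the values are invented, only the shapes and conventions follow the dataclass above:

    import torch

    # Hypothetical metadata for two requests with 3 and 5 tokens.
    toy = TRTLLMMHAMetadata(
        cache_seqlens_int32=torch.tensor([3, 5], dtype=torch.int32),
        max_seq_len_q=5,
        max_seq_len_k=5,
        cu_seqlens_q=torch.tensor([0, 3, 8], dtype=torch.int32),  # prefix sums of query lens
        cu_seqlens_k=torch.tensor([0, 3, 8], dtype=torch.int32),  # prefix sums of key lens
        page_table=torch.zeros(2, 1, dtype=torch.int32),          # one 64-token page each
    )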
class TRTLLMHAAttnBackend(FlashInferAttnBackend):
    """TRTLLM MHA attention kernel from flashinfer."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
        config = model_runner.model_config

        # MHA-specific dimensions
        self.max_context_len = model_runner.model_config.context_len
        self.sliding_window_size = (
            model_runner.sliding_window_size
            if model_runner.sliding_window_size is not None
            else -1  # -1 indicates full attention
        )
        self.hidden_size = config.hidden_size

        # Runtime parameters
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.page_size = model_runner.page_size
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.device = model_runner.device

        # Workspace allocation
        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
        self.workspace_buffer = torch.empty(
            self.workspace_size, dtype=torch.int8, device=self.device
        )

        # CUDA graph state
        self.decode_cuda_graph_metadata = {}

        # Forward metadata
        self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        """Initialize CUDA graph state for TRTLLM MHA."""
        self.decode_cuda_graph_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "page_table": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,
                dtype=torch.int32,
                device=self.device,
            ),
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }
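The page-table width above is simply the ceiling division of the maximum context length by the page size; a quick sketch with made-up numbers:

    # Hypothetical sizes, not taken from any real model config.
    max_context_len, page_size = 8192, 64
    pages_per_request = (max_context_len + page_size - 1) // page_size  # ceil division
    assert pages_per_request == 128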
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Initialize metadata for CUDA graph capture."""
        metadata = TRTLLMMHAMetadata()

        # Get sequence information
        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
        # Precompute maximum sequence length
        metadata.max_seq_len_k = seq_lens.max().item()
        # Precompute page table
        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Replay CUDA graph with new inputs."""
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]
        device = seq_lens.device
        metadata = None

        # Normal Decode
        metadata = self.decode_cuda_graph_metadata[bs]
        max_len = seq_lens_cpu.max().item()
        max_seq_pages = (max_len + self.page_size - 1) // self.page_size

        metadata.max_seq_len_k = max_len
        metadata.cache_seqlens_int32.copy_(seq_lens)
        page_indices = self.req_to_token[
            req_pool_indices[:, None],
            self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
        ]
        metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
        self.forward_metadata = metadata

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        """Get the fill value for sequence lengths in CUDA graph."""
        return 1
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize the metadata for a forward pass."""
        metadata = TRTLLMMHAMetadata()
        seqlens_in_batch = forward_batch.seq_lens
        batch_size = forward_batch.batch_size
        device = seqlens_in_batch.device
        if forward_batch.forward_mode.is_decode_or_idle():
            # Normal Decode
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
        else:
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
            if any(forward_batch.extend_prefix_lens_cpu):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
            else:
                metadata.max_seq_len_q = metadata.max_seq_len_k
                metadata.cu_seqlens_q = metadata.cu_seqlens_k

        # Convert the page table to a strided format
        if self.page_size > 1:
            self.strided_indices = torch.arange(
                0, metadata.page_table.shape[1], self.page_size, device=self.device
            )
            metadata.page_table = (
                metadata.page_table[:, self.strided_indices] // self.page_size
            )
        self.forward_metadata = metadata
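The strided conversion at the end turns the token-level req_to_token mapping into block ids for the kernel's block_tables argument: keep every page_size-th column, then integer-divide by page_size. A toy illustration with invented slot numbers and page_size=4:

    import torch

    page_size = 4
    # One request whose first 8 tokens occupy KV-pool slots 20..27 (made-up values).
    token_table = torch.tensor([[20, 21, 22, 23, 24, 25, 26, 27]])
    strided_indices = torch.arange(0, token_table.shape[1], page_size)  # tensor([0, 4])
    block_table = token_table[:, strided_indices] // page_size          # tensor([[5, 6]])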
    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ) -> torch.Tensor:
        """Run forward for decode using TRTLLM MHA kernel."""
        cache_loc = forward_batch.out_cache_loc
        if save_kv_cache and k is not None:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )

        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        # shape conversion:
        # [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
        k_cache = k_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        v_cache = v_cache.view(
            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        kv_cache = (k_cache, v_cache)

        # TODO: bmm1_scale and bmm2_scale might require modification
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float
            if getattr(layer, "k_scale_float", None) is not None
            else 1.0
        )
        bmm1_scale = q_scale * k_scale * layer.scaling
        bmm2_scale = 1.0

        # Call TRT-LLM kernel
        # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
        o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            query=q,
            kv_cache=kv_cache,
            workspace_buffer=self.workspace_buffer,
            block_tables=self.forward_metadata.page_table,
            seq_lens=self.forward_metadata.cache_seqlens_int32,
            max_seq_len=self.forward_metadata.max_seq_len_k,
            bmm1_scale=bmm1_scale,
            bmm2_scale=bmm2_scale,
            window_left=self.sliding_window_size,
            # TODO: add attention_sink operation or nvfp4 scale factor if needed
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
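The view/permute above only rearranges the paged KV pool into the [num_blocks, num_kv_heads, page_size, head_dim] layout the kernel expects; a shape-only sketch with placeholder dimensions:

    import torch

    num_pages, page_size, num_kv_heads, head_dim = 3, 64, 8, 128  # placeholder sizes
    k_pool = torch.empty(num_pages * page_size, num_kv_heads, head_dim)
    k_paged = k_pool.view(-1, page_size, num_kv_heads, head_dim).permute(0, 2, 1, 3)
    assert k_paged.shape == (num_pages, num_kv_heads, page_size, head_dim)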
    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        cache_loc = forward_batch.out_cache_loc
        if save_kv_cache and k is not None:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )

        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        k_cache = k_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        v_cache = v_cache.view(
            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        kv_cache = (k_cache, v_cache)

        # TODO: bmm1_scale and bmm2_scale might require modification
        # TODO: Change once quantization is supported
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float
            if getattr(layer, "k_scale_float", None) is not None
            else 1.0
        )
        bmm1_scale = q_scale * k_scale * layer.scaling
        bmm2_scale = 1.0

        o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
            query=q,
            kv_cache=kv_cache,
            workspace_buffer=self.workspace_buffer,
            block_tables=self.forward_metadata.page_table,
            seq_lens=self.forward_metadata.cache_seqlens_int32,
            max_q_len=self.forward_metadata.max_seq_len_q,
            max_kv_len=self.forward_metadata.max_seq_len_k,
            bmm1_scale=bmm1_scale,
            bmm2_scale=bmm2_scale,
            batch_size=forward_batch.batch_size,
            cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
            cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
            window_left=self.sliding_window_size,
            # TODO: add attention_sink operation or nvfp4 scale factor if needed
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
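In both paths the softmax scale is pre-folded into bmm1_scale; with the quantization scales fixed at 1.0 for now, it reduces to the layer's usual 1/sqrt(head_dim) factor. A small numeric check with a hypothetical head_dim:

    import math

    head_dim = 128                              # hypothetical
    layer_scaling = 1.0 / math.sqrt(head_dim)   # typical RadixAttention scaling
    q_scale = k_scale = 1.0                     # no quantization scales yet (see TODOs)
    bmm1_scale = q_scale * k_scale * layer_scaling
    print(round(bmm1_scale, 4))  # 0.0884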
python/sglang/srt/managers/schedule_batch.py

@@ -1705,6 +1705,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            or attention_backend_str == "flashmla"
            or attention_backend_str == "cutlass_mla"
            or attention_backend_str == "ascend"
            or attention_backend_str == "trtllm_mha"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (
python/sglang/srt/model_executor/model_runner.py

@@ -1449,6 +1449,17 @@ class ModelRunner:
            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

            return TRTLLMMLABackend(self)
        elif self.server_args.attention_backend == "trtllm_mha":
            if self.use_mla_backend:
                raise ValueError(
                    "trtllm_mha backend can only be used with non-MLA models."
                )
            from sglang.srt.layers.attention.trtllm_mha_backend import (
                TRTLLMHAAttnBackend,
            )

            return TRTLLMHAAttnBackend(self)
        elif self.server_args.attention_backend == "intel_amx":
            from sglang.srt.layers.attention.intel_amx_backend import (
                IntelAMXAttnBackend,
python/sglang/srt/server_args.py

@@ -441,6 +441,23 @@ class ServerArgs:
                    "trtllm_mla backend does not support speculative decoding yet."
                )

        if self.attention_backend == "trtllm_mha":
            if not is_sm100_supported():
                raise ValueError(
                    "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
                )

            if self.page_size not in [16, 32, 64]:
                logger.warning(
                    f"TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from {self.page_size} to 64."
                )
                self.page_size = 64

            if self.speculative_algorithm is not None:
                raise ValueError(
                    "trtllm_mha backend does not support speculative decoding yet."
                )

        # Set page size
        if self.page_size is None:
            self.page_size = 1

@@ -1275,6 +1292,7 @@ class ServerArgs:
                "ascend",
                "triton",
                "trtllm_mla",
                "trtllm_mha",
            ],
            default=ServerArgs.attention_backend,
            help="Choose the kernels for attention layers.",