Commit 64574ef8 (unverified)
Authored Aug 21, 2025 by pranavm-nvidia; committed by GitHub, Aug 21, 2025
Parent: 18da2c96

Enables speculative decoding for the trtllm_mla attention backend (#9238)
Showing 3 changed files with 60 additions and 21 deletions.
python/sglang/srt/layers/attention/trtllm_mla_backend.py (+39, -16)
python/sglang/srt/server_args.py (+0, -5)
python/sglang/srt/speculative/eagle_worker.py (+21, -0)
python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union

 import torch
 import triton

-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import (
+    FlashInferMLAAttnBackend,
+    FlashInferMLAMultiStepDraftBackend,
+)
 from sglang.srt.layers.attention.utils import (
     TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
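Note: the import is widened to also pull in FlashInferMLAMultiStepDraftBackend, which the new TRTLLMMLAMultiStepDraftBackend added at the end of this file subclasses; TRTLLMMLABackend itself still derives from FlashInferMLAAttnBackend.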
@@ -96,7 +99,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):

         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
-        self.cuda_graph_kv_indices = None
+        self.decode_cuda_graph_kv_indices = None
         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

     def _calc_padded_blocks(self, max_seq_len: int) -> int:
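Note: the decode_ prefix appears intended to keep the TRT-LLM decode fast path's CUDA-graph buffers distinct from state the FlashInfer parent manages, since init_cuda_graph_state in the next hunk now also calls super().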
@@ -167,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MLA."""
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
-        self.cuda_graph_kv_indices = torch.full(
+        self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_workspace = torch.empty(
+        self.decode_cuda_graph_workspace = torch.empty(
             self.workspace_size, dtype=torch.int8, device=self.device
         )

+        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
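For context, the preallocated index table is sized by _calc_padded_blocks, which (per the imports above) pads the per-sequence page count up to a Triton-friendly multiple. A minimal self-contained sketch of that sizing, with the page size (64, forced in server_args.py below) grounded in this commit and the padding multiple a hypothetical stand-in:

import torch

PAGE_SIZE = 64      # trtllm_mla forces page_size = 64 (see server_args.py)
PAD_MULTIPLE = 64   # hypothetical stand-in for TRITON_PAD_NUM_PAGE_PER_BLOCK

def calc_padded_blocks(max_seq_len: int) -> int:
    # Pages needed for the longest sequence, rounded up to the padding multiple.
    blocks = -(-max_seq_len // PAGE_SIZE)
    return -(-blocks // PAD_MULTIPLE) * PAD_MULTIPLE

max_bs, max_context_len = 8, 4096
kv_indices = torch.full(
    (max_bs, calc_padded_blocks(max_context_len)), -1, dtype=torch.int32
)
print(kv_indices.shape)  # torch.Size([8, 64]); unused slots keep the -1 sentinel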
@@ -187,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         spec_info: Optional[SpecInfo],
     ):
         """Initialize metadata for CUDA graph capture."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
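This predicate change is the core of the commit and recurs in the replay and eager-init hunks below: decode/idle batches now take the TRT-LLM fast path even when spec_info is set, which is what lets EAGLE draft decode steps run on this backend. A boolean sketch of the before/after routing (argument names are illustrative):

def old_uses_fast_path(is_decode_or_idle: bool, has_spec_info: bool) -> bool:
    return is_decode_or_idle and not has_spec_info

def new_uses_fast_path(is_decode_or_idle: bool, has_spec_info: bool) -> bool:
    return is_decode_or_idle  # spec_info no longer forces the parent path

# A speculative decode batch: previously delegated to the parent, now fast path.
assert not old_uses_fast_path(True, True)
assert new_uses_fast_path(True, True)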
@@ -199,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             spec_info,
         )

-        # Custom fast-path for decode/idle without speculative execution.
+        # Custom fast-path for decode/idle.
         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -215,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )

-        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        metadata = TRTLLMMLADecodeMetadata(
+            self.decode_cuda_graph_workspace, block_kv_indices
+        )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -231,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -265,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (
-            forward_batch.forward_mode.is_decode_or_idle()
-            and forward_batch.spec_info is None
-        ):
+        # Delegate to parent for non-decode modes.
+        if not forward_batch.forward_mode.is_decode_or_idle():
             return super().init_forward_metadata(forward_batch)

         bs = forward_batch.batch_size
@@ -474,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)

         return output
+
+
+class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
+    """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
+
+    def __init__(
+        self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i] = TRTLLMMLABackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                q_indptr_decode_buf=self.q_indptr_decode,
+            )
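The new class gives each EAGLE draft step its own TRTLLMMLABackend, with kv_indptr_buf=self.kv_indptr[i] handing step i a private slice of the shared index-pointer buffer so per-step CUDA-graph state never aliases. A toy, self-contained illustration of the one-backend-per-step pattern (not the sglang API):

from dataclasses import dataclass

@dataclass
class ToyStepBackend:
    # Stand-in for one per-step attention backend with its own state.
    step: int

    def decode(self, token: str) -> str:
        return f"{token}->d{self.step}"

class ToyMultiStepDraft:
    def __init__(self, num_steps: int):
        # One backend per draft step, mirroring attn_backends[i] above.
        self.backends = [ToyStepBackend(i) for i in range(num_steps)]

    def draft(self, token: str) -> list:
        chain = []
        for backend in self.backends:
            token = backend.decode(token)
            chain.append(token)
        return chain

print(ToyMultiStepDraft(3).draft("t0"))
# ['t0->d0', 't0->d0->d1', 't0->d0->d1->d2']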
python/sglang/srt/server_args.py
@@ -479,11 +479,6 @@ class ServerArgs:
             )
             self.page_size = 64
-            if self.speculative_algorithm is not None:
-                raise ValueError(
-                    "trtllm_mla backend does not support speculative decoding yet."
-                )
             if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
                 raise ValueError(
                     "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
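With the guard removed, requesting a speculative algorithm no longer errors out under trtllm_mla; only the kv-cache dtype restriction remains. A minimal sketch of the surviving validation (standalone function, not the actual ServerArgs method):

def check_trtllm_mla_args(speculative_algorithm, kv_cache_dtype):
    # The "does not support speculative decoding yet" check is gone after this
    # commit; only the kv-cache dtype restriction survives.
    if kv_cache_dtype not in ("fp8_e4m3", "auto"):
        raise ValueError(
            "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
        )

check_trtllm_mla_args("EAGLE", "auto")   # now accepted
check_trtllm_mla_args(None, "fp8_e4m3")  # still accepted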
python/sglang/srt/speculative/eagle_worker.py
@@ -266,6 +266,27 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mla":
+            if not global_server_args_dict["use_mla_backend"]:
+                raise ValueError(
+                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+                )
+
+            from sglang.srt.layers.attention.trtllm_mla_backend import (
+                TRTLLMMLABackend,
+                TRTLLMMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMMLABackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
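The new branch wires trtllm_mla the same way other supported backends are wired: a multi-step backend drives the draft decode loop, a full TRTLLMMLABackend with skip_prefill=False handles draft-extend, and has_prefill_wrapper_verify=True routes verification through a prefill wrapper. A toy dispatcher showing the before/after behavior of this branch (other backends elided):

def pick_draft_backend(attention_backend: str) -> str:
    # Toy stand-in for EAGLEWorker's backend dispatch; only the branch this
    # commit adds is shown faithfully.
    if attention_backend == "trtllm_mla":
        return "TRTLLMMLAMultiStepDraftBackend"
    raise ValueError(
        f"EAGLE is not supported in attention backend {attention_backend}"
    )

print(pick_draft_backend("trtllm_mla"))  # no longer raises after this commit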