change / sglang / Commits / cd0b5891

Commit cd0b5891 authored Nov 10, 2025 by linhai1

refer to flashmla to add decode backend.

Parent: d629db06
Showing 3 changed files with 184 additions and 37 deletions:

    python/sglang/srt/layers/attention/dcu_mla_backend.py         +155  -33
    python/sglang/srt/layers/attention/flashattention_backend.py   +11   -0
    python/sglang/srt/speculative/draft_utils.py                    +18   -4
python/sglang/srt/layers/attention/dcu_mla_backend.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
 
 import torch
 import triton
@@ -47,6 +47,16 @@ class VllmMLADecodeMetadata:
     num_splits: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
 
+    def __init__(
+        self,
+        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_splits: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.flashmla_metadata = flashmla_metadata
+        self.num_splits = num_splits
+        self.block_kv_indices = block_kv_indices
+
 
 class DCUMLABackend(AttentionBackend):
     def __init__(
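Note (illustrative): the three fields above are constructed positionally later in this diff, as VllmMLADecodeMetadata(mla_metadata, num_splits, block_kv_indices). A stand-alone sketch of the container with placeholder tensors; DecodeMetadataSketch is a stand-in defined here, not part of the commit:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class DecodeMetadataSketch:
    # Mirrors VllmMLADecodeMetadata: a pair of scheduling tensors from
    # get_mla_metadata, the num_splits tensor, and the per-request block table.
    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    num_splits: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None


meta = DecodeMetadataSketch(
    (torch.zeros(8, dtype=torch.int32), torch.ones(4, dtype=torch.int32)),
    torch.ones(4, dtype=torch.int32),
    torch.full((3, 4), -1, dtype=torch.int32),
)
print(meta.block_kv_indices.shape)  # torch.Size([3, 4])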
@@ -92,22 +102,25 @@ class DCUMLABackend(AttentionBackend):
             skip_prefill=False,
         )
 
-    def _build_decode_metadata(
-        self, forward_batch: ForwardBatch,
-        seq_lens: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
-        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
-        block_kv_indices = torch.full(
-            (bs, max_seqlen_pad),
-            -1,
-            dtype=torch.int32,
-            device=seq_lens.device
-        )
-        create_flashmla_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            forward_batch.req_pool_indices,
-            seq_lens,
-            None,
-            block_kv_indices,
-            self.req_to_token.stride(0),
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            max_seqlen_pad = triton.cdiv(
+                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
+            )
+            block_kv_indices = torch.full(
+                (bs, max_seqlen_pad),
+                -1,
+                dtype=torch.int32,
+                device=forward_batch.seq_lens.device
+            )
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                None,
+                block_kv_indices,
+                self.req_to_token.stride(0),
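For illustration, the block-table sizing done in the decode branch above, in plain Python. PAGE_SIZE = 64 and the sequence lengths are assumed values for this sketch; the real constant and the Triton kernel that fills the table live in this backend module:

import torch

PAGE_SIZE = 64  # assumption for this sketch
seq_lens_cpu = torch.tensor([100, 37, 513])      # per-request KV lengths
bs = seq_lens_cpu.numel()

# triton.cdiv(max_len, PAGE_SIZE): pages needed by the longest request
max_seqlen_pad = -(-seq_lens_cpu.max().item() // PAGE_SIZE)

# One row per request, padded with -1; create_flashmla_kv_indices_triton then
# writes the real KV-cache page indices into each row.
block_kv_indices = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)
print(max_seqlen_pad, tuple(block_kv_indices.shape))  # 9 (3, 9)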
@@ -115,26 +128,44 @@ class DCUMLABackend(AttentionBackend):
-        )
-        mla_metadata, num_splits = get_mla_metadata(
-            seq_lens.to(torch.int32), self.num_q_heads, 1
-        )
-        return (mla_metadata, num_splits), num_splits, block_kv_indices
-
-    def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode_or_idle():
-            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
-                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
-            )
-            self.forward_metadata = VllmMLADecodeMetadata(
-                mla_metadata, num_splits_t, block_kv_indices
-            )
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                forward_batch.seq_lens.to(torch.int32),
+                self.num_q_heads,
+                1
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata, num_splits, block_kv_indices
+            )
         elif forward_batch.forward_mode.is_target_verify():
             seq_lens = forward_batch.seq_lens + self.num_draft_tokens
-            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
-                self._build_decode_metadata(forward_batch, seq_lens)
-            )
-            self.forward_metadata = VllmMLADecodeMetadata(
-                mla_metadata, num_splits_t, block_kv_indices
-            )
+            seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            block_kv_indices = torch.full(
+                (bs, max_seqlen_pad),
+                -1,
+                dtype=torch.int32,
+                device=seq_lens.device,
+            )
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                seq_lens,
+                None,
+                block_kv_indices,
+                self.req_to_token.stride(0),
+                max_seqlen_pad,
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32),
+                self.num_draft_tokens * self.num_q_heads,
+                1,
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata, num_splits, block_kv_indices
+            )
         else:
             if not self.skip_prefill:
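The target-verify branch sizes its metadata the same way, except that every sequence length is extended by the draft length and the query-head count passed to get_mla_metadata is scaled by it. A small numeric sketch; PAGE_SIZE, num_draft_tokens, and num_q_heads below are assumed values, not taken from this commit:

import math

import torch

PAGE_SIZE, num_draft_tokens, num_q_heads = 64, 4, 16   # assumptions
seq_lens_cpu = torch.tensor([100, 37])

# Each request verifies num_draft_tokens positions on top of its committed KV,
# mirroring seq_lens = forward_batch.seq_lens + self.num_draft_tokens above.
verify_lens = seq_lens_cpu + num_draft_tokens                        # tensor([104, 41])
max_seqlen_pad = math.ceil(verify_lens.max().item() / PAGE_SIZE)     # 2 pages

# get_mla_metadata is called with num_draft_tokens * num_q_heads query heads
# per KV head, which is how the FlashMLA-style metadata accounts for the
# extra query positions being verified per request.
q_heads_per_kv = num_draft_tokens * num_q_heads                      # 64
print(verify_lens.tolist(), max_seqlen_pad, q_heads_per_kv)          # [104, 41] 2 64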
@@ -450,4 +481,95 @@ class DCUMLABackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+
+class DCUMLAMultiStepDraftBackend:
+    """
+    Wrap multiple flashmla attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        if topk > 1:
+            raise ValueError(
+                "Currently FlashMLA only supports topk=1 for speculative decoding"
+            )
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends.append(
+                DCUMLABackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=None,
+                )
+            )
+
+    def common_template(
+        self,
+        forward_batch: ForwardBatch,
+        call_fn: Callable,
+    ):
+        assert forward_batch.spec_info is not None
+
+        for i in range(self.speculative_num_steps - 1):
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
+
+        self.common_template(forward_batch, call_fn)
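The new class follows the FlashMLA multi-step draft backend shape: one wrapper owns speculative_num_steps - 1 per-step backends and fans every metadata call out through common_template. A self-contained toy version of that delegation pattern; the stub classes below are stand-ins defined for this sketch, not sglang code:

from typing import Callable, List


class _StepBackendStub:
    """Stands in for a per-step DCUMLABackend."""

    def __init__(self, step: int):
        self.step = step
        self.prepared = False

    def init_forward_metadata(self, forward_batch) -> None:
        self.prepared = True


class _MultiStepDraftStub:
    def __init__(self, speculative_num_steps: int):
        # Mirrors the loop above: only speculative_num_steps - 1 backends are kept.
        self.attn_backends: List[_StepBackendStub] = [
            _StepBackendStub(i) for i in range(speculative_num_steps - 1)
        ]

    def common_template(self, forward_batch, call_fn: Callable) -> None:
        for i in range(len(self.attn_backends)):
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch) -> None:
        def call_fn(i, fb):
            self.attn_backends[i].init_forward_metadata(fb)

        self.common_template(forward_batch, call_fn)


wrapper = _MultiStepDraftStub(speculative_num_steps=3)
wrapper.init_forward_metadata(forward_batch=None)
print([b.prepared for b in wrapper.attn_backends])  # [True, True]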
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -871,9 +871,16 @@ class FlashAttentionBackend(AttentionBackend):
                 return_softmax_lse=forward_batch.mha_return_lse,
                 **kwargs,
             )
+            # if layer.layer_id == 0:
+            #     print('### mha output, q, k, v', output.shape, q.shape, k.shape, v.shape)
+            #torch.Size([136, 16, 128]) torch.Size([136, 16, 192]) torch.Size([136, 16, 192]) torch.Size([136, 16, 128])
+            #torch.Size([7, 16, 128]) torch.Size([7, 16, 192]) torch.Size([7, 16, 192]) torch.Size([7, 16, 128])
+            #torch.Size([40, 16, 128]) torch.Size([40, 16, 192]) torch.Size([40, 16, 192]) torch.Size([40, 16, 128])
             if forward_batch.mha_return_lse:
                 output, lse, *rest = output
                 lse = torch.transpose(lse, 0, 1).contiguous()
+                # if layer.layer_id == 0:
+                #     print('###output, lse, q, k, v', output.shape, lse.shape, q.shape, k.shape, v.shape)
                 return output, lse
             return output
         else:
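A small sketch of what the existing lse transpose in this hunk does. The (num_heads, num_tokens) input layout is an assumption about the attention kernel's return value, not something stated in this diff:

import torch

num_tokens, num_heads = 7, 16
lse = torch.randn(num_heads, num_tokens)        # assumed kernel layout
lse = torch.transpose(lse, 0, 1).contiguous()   # same call as in the code above
print(lse.shape, lse.is_contiguous())           # torch.Size([7, 16]) True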
@@ -921,6 +928,10 @@ class FlashAttentionBackend(AttentionBackend):
                 return_softmax_lse=use_cascade_attn,
                 num_splits=self.num_splits,
             )
+            # if layer.layer_id == 0:
+            #     print('### mla output, q, k, v', result.shape, q_rope.shape, k_rope_cache.shape, c_kv_cache.shape)
+            #torch.Size([8, 16, 512]) torch.Size([8, 16, 64]) torch.Size([3318, 64, 1, 64]) torch.Size([3318, 64, 1, 512])
+            #torch.Size([286, 16, 512]) torch.Size([286, 16, 64]) torch.Size([3322, 64, 1, 64]) torch.Size([3322, 64, 1, 512])
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
                 o_expand, softmax_lse_expand, *rest_expand = (
python/sglang/srt/speculative/draft_utils.py
@@ -27,10 +27,7 @@ class DraftBackendFactory:
         backend_type = self.server_args.attention_backend
 
         if backend_type not in backend_map:
-            if backend_type != "dcu_mla":
-                raise ValueError(error_template.format(backend_type=backend_type))
-            else:
-                return backend_map["fa3"]()
+            raise ValueError(error_template.format(backend_type=backend_type))
 
         return backend_map[backend_type]()
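With dcu_mla registered in the decode and prefill maps in the hunks below, the factory no longer needs the fa3 fallback removed above; a lookup either hits a real entry or raises. A minimal sketch of that dispatch, where the template string and the lambdas are placeholders rather than the real factory callables:

error_template = "Unsupported attention backend: {backend_type}"  # placeholder text
backend_map = {
    "flashmla": lambda: "flashmla draft backend",
    "dcu_mla": lambda: "dcu_mla draft backend",
}


def create_decode_backend(backend_type: str):
    if backend_type not in backend_map:
        raise ValueError(error_template.format(backend_type=backend_type))
    return backend_map[backend_type]()


print(create_decode_backend("dcu_mla"))  # dcu_mla draft backend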
@@ -49,6 +46,7 @@ class DraftBackendFactory:
                 else self._create_triton_decode_backend
             ),
             "flashmla": self._create_flashmla_decode_backend,
+            "dcu_mla": self._create_dcumla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
             "nsa": self._create_nsa_decode_backend,
@@ -72,6 +70,7 @@ class DraftBackendFactory:
                 else self._create_triton_prefill_backend
             ),
             "flashmla": self._create_flashmla_prefill_backend,
+            "dcu_mla": self._create_dcumla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
             "nsa": self._create_nsa_prefill_backend,
@@ -153,6 +152,15 @@ class DraftBackendFactory:
             self.draft_model_runner, self.topk, self.speculative_num_steps
         )
 
+    def _create_dcumla_decode_backend(self):
+        from sglang.srt.layers.attention.dcu_mla_backend import (
+            DCUMLAMultiStepDraftBackend,
+        )
+
+        return DCUMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
     def _create_trtllm_mha_decode_backend(self):
         from sglang.srt.layers.attention.trtllm_mha_backend import (
             TRTLLMHAAttnMultiStepDraftBackend,
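Like the neighboring *_decode_backend factory methods, the new _create_dcumla_decode_backend imports its backend module inside the method body, so the DCU MLA module is only loaded when that backend is actually selected. A generic sketch of that deferred-import style; the helper below is illustrative, and the names in the commented call are simply the ones from this commit:

import importlib


def make_draft_backend(module_name: str, class_name: str, *args):
    # Import deferred until the factory actually runs, keeping optional
    # backend dependencies off the default import path.
    module = importlib.import_module(module_name)
    return getattr(module, class_name)(*args)


# e.g. make_draft_backend(
#     "sglang.srt.layers.attention.dcu_mla_backend",
#     "DCUMLAMultiStepDraftBackend",
#     draft_model_runner, topk, speculative_num_steps,
# )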
@@ -227,3 +235,9 @@ class DraftBackendFactory:
             "flashmla prefill backend is not yet supported for draft extend."
         )
         return None
+
+    def _create_dcumla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None