change / sglang · Commits

Commit d629db06, authored Nov 08, 2025 by linhai1
add draft_extend support for dcu_mla.
parent 4d106b5f
Showing 4 changed files with 22 additions and 79 deletions (+22 -79):

  python/sglang/srt/layers/attention/dcu_mla_backend.py         +13  -73
  python/sglang/srt/layers/attention/flashattention_backend.py   +4   -2
  python/sglang/srt/models/deepseek_v2.py                        +1   -3
  python/sglang/srt/speculative/draft_utils.py                   +4   -1
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -86,14 +86,6 @@ class DCUMLABackend(AttentionBackend):
         self.skip_prefill = skip_prefill

         if not skip_prefill:
-            # Use the triton backend for now; consider replacing it later
-            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-            # self.triton_backend = TritonAttnBackend(
-            #     model_runner,
-            #     skip_prefill=False,
-            #     kv_indptr_buf=kv_indptr_buf,
-            # )
-            # Switch prefill to flash attn
             from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
             self.flashattn_backend = FlashAttentionBackend(
                 model_runner,
@@ -109,7 +101,6 @@ class DCUMLABackend(AttentionBackend):
         bs = forward_batch.batch_size
         max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
-        # Paging scheme follows the vLLM official blog post
         block_kv_indices = torch.full(
             (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
         )
@@ -131,7 +122,6 @@ class DCUMLABackend(AttentionBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
-            # Use flashmla for decode
             (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                 self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
             )
@@ -147,9 +137,7 @@ class DCUMLABackend(AttentionBackend):
                 mla_metadata, num_splits_t, block_kv_indices
             )
         else:
-            # prefill/extend used the triton backend -> switched to flash attn
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata(forward_batch)
                 self.flashattn_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
@@ -241,15 +229,6 @@ class DCUMLABackend(AttentionBackend):
             )
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
-                #     bs,
-                #     num_tokens,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                # )
                 self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                     bs,
                     num_tokens,
@@ -321,16 +300,6 @@ class DCUMLABackend(AttentionBackend):
             ]
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
-                #     bs,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     seq_lens_sum,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                #     seq_lens_cpu,
-                # )
                 self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                     bs,
                     req_pool_indices,
@@ -387,30 +356,18 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
-        q_rope: Optional[torch.Tensor] = None,
-        k_rope: Optional[torch.Tensor] = None,
-        sinks: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc

         if k is not None:
             assert v is not None
             if save_kv_cache:
-                if k_rope is None:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-                else:
-                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        k_rope,
-                    )
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )

         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -444,22 +401,14 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
-        q_rope=None,
-        k_rope=None,
-        sinks=None,
     ):
         if (
             forward_batch.forward_mode == ForwardMode.EXTEND
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
         ):
-            # flash_attn does not support fp8; extend cannot run correctly with fp8
             if not self.skip_prefill:
-                # return self.triton_backend.forward_extend(
-                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
-                # )
                 return self.flashattn_backend.forward_extend(
-                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
+                    q, k, v, layer, forward_batch, save_kv_cache,
                 )
             else:
                 raise RuntimeError("skip prefill but use forward_extend")
@@ -468,21 +417,12 @@ class DCUMLABackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                # forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
-                if k_rope is None:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        v,
-                    )
-                else:
-                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
-                        layer,
-                        cache_loc,
-                        k,
-                        k_rope,
-                    )
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )

         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
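Taken together, these hunks leave DCUMLABackend delegating every EXTEND / DRAFT_EXTEND batch to the wrapped FlashAttentionBackend while keeping its own MLA decode path. Below is a minimal, self-contained sketch of that dispatch pattern; apart from the forward-mode names, every class and method here (HybridBackend, DummyPrefillBackend, forward) is a hypothetical stand-in rather than the real sglang API.

# Hypothetical sketch of the dispatch pattern: extend-like modes go to a
# wrapped prefill backend, decode stays on the backend's own kernel.
from enum import Enum, auto


class ForwardMode(Enum):
    EXTEND = auto()
    DRAFT_EXTEND = auto()
    DECODE = auto()


class DummyPrefillBackend:
    def forward_extend(self, q, k, v):
        # Stand-in for FlashAttentionBackend.forward_extend.
        return f"prefill attention over {len(q)} query tokens"


class HybridBackend:
    def __init__(self, skip_prefill: bool = False):
        self.skip_prefill = skip_prefill
        if not skip_prefill:
            self.prefill_backend = DummyPrefillBackend()

    def forward(self, q, k, v, mode: ForwardMode):
        if mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
            if self.skip_prefill:
                raise RuntimeError("skip prefill but use forward_extend")
            return self.prefill_backend.forward_extend(q, k, v)
        # Decode batches would use the backend's own MLA decode kernel.
        return f"mla decode over {len(q)} query tokens"


if __name__ == "__main__":
    backend = HybridBackend()
    print(backend.forward([0, 1, 2], [0, 1, 2], [0, 1, 2], ForwardMode.DRAFT_EXTEND))
    print(backend.forward([0], [0], [0], ForwardMode.DECODE))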
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -668,9 +668,11 @@ class FlashAttentionBackend(AttentionBackend):
                     if not layer.is_cross_attention
                     else forward_batch.encoder_out_cache_loc
                 )
-                if not self.use_mla:
+                # if not self.use_mla:
+                if k_rope is None:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        # layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        layer, cache_loc, k, v
                     )
                 else:
                     forward_batch.token_to_kv_pool.set_mla_kv_buffer(
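This hunk keys the KV-cache write on whether a separate k_rope tensor was passed in, rather than on the use_mla flag, and drops the fp8 scales from the combined write. A rough sketch of that branching follows, with a toy in-memory pool standing in for token_to_kv_pool; ToyKVPool, its buffer layout, and save_kv are assumptions made only for illustration.

# Toy illustration of choosing between a combined KV write and a split
# (no-PE + RoPE) MLA write based on whether k_rope is provided.
# ToyKVPool and save_kv are made-up stand-ins, not sglang's token_to_kv_pool.
from typing import Optional

import torch


class ToyKVPool:
    def __init__(self, num_slots: int, entry_dim: int):
        self.kv_buffer = torch.zeros(num_slots, entry_dim)

    def set_kv_buffer(self, loc: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
        # Combined path: store key and value parts side by side in one entry.
        self.kv_buffer[loc] = torch.cat([k, v], dim=-1)

    def set_mla_kv_buffer(self, loc: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor) -> None:
        # MLA path: store the compressed (no-PE) part next to the RoPE part.
        self.kv_buffer[loc] = torch.cat([k_nope, k_rope], dim=-1)


def save_kv(pool: ToyKVPool, loc, k, v, k_rope: Optional[torch.Tensor] = None) -> None:
    if k_rope is None:
        pool.set_kv_buffer(loc, k, v)
    else:
        pool.set_mla_kv_buffer(loc, k, k_rope)


pool = ToyKVPool(num_slots=8, entry_dim=6)
loc = torch.tensor([0, 1])
save_kv(pool, loc, k=torch.randn(2, 4), v=torch.randn(2, 2))               # combined write
save_kv(pool, loc, k=torch.randn(2, 4), v=None, k_rope=torch.randn(2, 2))  # MLA split write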
python/sglang/srt/models/deepseek_v2.py

@@ -1662,9 +1662,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions,
         topk_indices,
     ):
-        # if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
-        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or \
-            (not forward_batch.forward_mode.is_decode() and self.current_attention_backend == 'dcu_mla'):
+        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
                 extra_args = {
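The restored condition gates this call style purely on membership in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS, so dcu_mla is no longer special-cased in non-decode modes. A stripped-down sketch of that kind of membership dispatch is below; the set contents and both helper functions are placeholders, not the real sglang definitions.

# Placeholder illustration of gating a call style on backend membership.
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = {"fa3", "flashinfer", "flashmla"}  # placeholder members


def attn_with_split_rope(q_nope: str, q_rope: str) -> str:
    # Stand-in for the path that passes the no-PE and RoPE parts separately.
    return f"split-rope call: nope={q_nope}, rope={q_rope}"


def attn_with_concat_q(q: str) -> str:
    # Stand-in for the path that passes a single concatenated query.
    return f"concat-q call: q={q}"


def run_attention(backend: str, q_nope: str, q_rope: str) -> str:
    if backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
        return attn_with_split_rope(q_nope, q_rope)
    return attn_with_concat_q(q_nope + q_rope)


print(run_attention("fa3", "qn", "qr"))      # member backend -> split-rope path
print(run_attention("dcu_mla", "qn", "qr"))  # non-member -> concat path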
python/sglang/srt/speculative/draft_utils.py

@@ -27,7 +27,10 @@ class DraftBackendFactory:
         backend_type = self.server_args.attention_backend

         if backend_type not in backend_map:
-            raise ValueError(error_template.format(backend_type=backend_type))
+            if backend_type != "dcu_mla":
+                raise ValueError(error_template.format(backend_type=backend_type))
+            else:
+                return backend_map["fa3"]()

         return backend_map[backend_type]()
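This draft_utils change is what lets the speculative draft worker run when the configured attention_backend is dcu_mla: the factory's backend_map has no dcu_mla entry, so instead of raising it falls back to the fa3 draft backend for that one name. A self-contained sketch of the fallback is below, using dummy backend classes and an assumed error_template string.

# Minimal sketch of the fallback pattern: unknown backends still raise, but
# "dcu_mla" reuses the fa3 draft backend. FA3Draft, TritonDraft, and the
# error_template text are dummies, not sglang's actual classes or message.
class FA3Draft:
    name = "fa3"


class TritonDraft:
    name = "triton"


def create_draft_backend(backend_type: str):
    backend_map = {"fa3": FA3Draft, "triton": TritonDraft}
    error_template = "{backend_type} is not supported as a draft attention backend"

    if backend_type not in backend_map:
        if backend_type != "dcu_mla":
            raise ValueError(error_template.format(backend_type=backend_type))
        else:
            return backend_map["fa3"]()

    return backend_map[backend_type]()


print(create_draft_backend("dcu_mla").name)  # falls back to fa3
print(create_draft_backend("triton").name)   # normal lookup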