Commit 3246cea1, authored Nov 11, 2025 by lizhigong

Merge branch 'v0.5.4_dev_linhai' into 'v0.5.4_dev'

V0.5.4 dev linhai

See merge request OpenDAS/sglang!17

Parents: 93eb92f8, 59b01a00

Showing 4 changed files with 168 additions and 52 deletions (+168 -52):
python/sglang/srt/layers/attention/dcu_mla_backend.py          +140  -43
python/sglang/srt/layers/attention/flashattention_backend.py    +10   -8
python/sglang/srt/models/deepseek_v2.py                           +1   -1
python/sglang/srt/speculative/draft_utils.py                     +17   -0
python/sglang/srt/layers/attention/dcu_mla_backend.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
 import torch
 import triton
@@ -47,6 +47,16 @@ class VllmMLADecodeMetadata:
     num_splits: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None

+    def __init__(
+        self,
+        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_splits: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.flashmla_metadata = flashmla_metadata
+        self.num_splits = num_splits
+        self.block_kv_indices = block_kv_indices
+

 class DCUMLABackend(AttentionBackend):
     def __init__(
@@ -86,14 +96,6 @@ class DCUMLABackend(AttentionBackend):
         self.skip_prefill = skip_prefill

         if not skip_prefill:
-            # Use the triton backend for now; consider replacing it later
-            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-            # self.triton_backend = TritonAttnBackend(
-            #     model_runner,
-            #     skip_prefill=False,
-            #     kv_indptr_buf=kv_indptr_buf,
-            # )
-            # Prefill now uses flash attn instead
             from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
             self.flashattn_backend = FlashAttentionBackend(
                 model_runner,
@@ -147,9 +149,7 @@ class DCUMLABackend(AttentionBackend):
                 mla_metadata, num_splits_t, block_kv_indices
             )
         else:
-            # prefill/extend used the triton backend -> switched to flash attn
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata(forward_batch)
                 self.flashattn_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
@@ -241,15 +241,6 @@ class DCUMLABackend(AttentionBackend):
             )
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
-                #     bs,
-                #     num_tokens,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                # )
                 self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                     bs,
                     num_tokens,
@@ -321,16 +312,6 @@ class DCUMLABackend(AttentionBackend):
             ]
         else:
             if not self.skip_prefill:
-                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
-                #     bs,
-                #     req_pool_indices,
-                #     seq_lens,
-                #     seq_lens_sum,
-                #     encoder_lens,
-                #     forward_mode,
-                #     spec_info,
-                #     seq_lens_cpu,
-                # )
                 self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                     bs,
                     req_pool_indices,
@@ -363,7 +344,7 @@ class DCUMLABackend(AttentionBackend):
     def _call_fp8_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
                          block_table: torch.Tensor, cache_seqlens: torch.Tensor,
-                         scaling: float):
+                         scaling: float, k_scale=None, kv_cache_dtype=None):
         assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
         o, _ = flash_mla_with_kvcache_quantization(
             q=reshape_q,
@@ -375,7 +356,8 @@ class DCUMLABackend(AttentionBackend):
             num_splits=self.forward_metadata.num_splits,
             softmax_scale=scaling,
             causal=True,
-            is_fp8_kvcache=True,
+            k_scale=k_scale,
+            kv_cache_dtype=kv_cache_dtype,
         )
         return o
@@ -412,14 +394,29 @@ class DCUMLABackend(AttentionBackend):
             getattr(torch, "float8_e5m2", None),
             getattr(torch, "float8_e5m2fnuz", None),
         ):
+            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
+                    k_cache_reshaped.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
+            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
+                    k_cache_reshaped.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
             o = self._call_fp8_decode(
-                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32), layer.scaling,
+                reshape_q,
+                k_cache_reshaped,
+                self.forward_metadata.block_kv_indices[:bs],
+                forward_batch.seq_lens.to(torch.int32),
+                layer.scaling,
+                k_scale.to(torch.float32),
+                kv_cache_dtype=kv_cache_dtype,
             )
         else:
             o = self._call_decode(
-                reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32), layer.scaling,
+                reshape_q,
+                k_cache_reshaped,
+                self.forward_metadata.block_kv_indices[:bs],
+                forward_batch.seq_lens.to(torch.int32),
+                layer.scaling,
             )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
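
The branch added above does two things before calling the FP8 decode kernel: it maps the torch fp8 storage dtype of the reshaped KV cache to the string label ("fp8_e4m3" or "fp8_e5m2") expected by the quantized flash_mla call, and it falls back to a unit scale when the layer carries no k_scale. A minimal standalone sketch of that mapping, outside the diff (the helper name is illustrative, not part of this change):

    import torch

    def fp8_kv_cache_label(dtype: torch.dtype):
        # Collect the fp8 dtypes guardedly, since older torch builds lack some of them.
        e4m3 = {getattr(torch, name, None) for name in ("float8_e4m3fn", "float8_e4m3fnuz")}
        e5m2 = {getattr(torch, name, None) for name in ("float8_e5m2", "float8_e5m2fnuz")}
        if dtype in e4m3:
            return "fp8_e4m3"
        if dtype in e5m2:
            return "fp8_e5m2"
        return None  # not an fp8 KV cache; the caller takes the non-FP8 decode path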
@@ -432,7 +429,6 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
@@ -445,11 +441,7 @@ class DCUMLABackend(AttentionBackend):
                 forward_batch.forward_mode == ForwardMode.EXTEND
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
             )
         ):
-            # flash_attn does not support fp8, so extend cannot run with fp8
             if not self.skip_prefill:
-                # return self.triton_backend.forward_extend(
-                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
-                # )
                 return self.flashattn_backend.forward_extend(
                     q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
                 )
@@ -474,14 +466,27 @@ class DCUMLABackend(AttentionBackend):
             getattr(torch, "float8_e5m2", None),
             getattr(torch, "float8_e5m2fnuz", None),
         ):
+            if k_cache_reshaped.dtype == torch.float8_e4m3fnuz or \
+                    k_cache_reshaped.dtype == torch.float8_e4m3fn:
+                kv_cache_dtype = "fp8_e4m3"
+            elif k_cache_reshaped.dtype == torch.float8_e5m2fnuz or \
+                    k_cache_reshaped.dtype == torch.float8_e5m2:
+                kv_cache_dtype = "fp8_e5m2"
+            k_scale = layer.k_scale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=reshape_q.device)
             o = self._call_fp8_decode(
                 reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
                 (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                 layer.scaling,
+                k_scale.to(torch.float32),
+                kv_cache_dtype=kv_cache_dtype,
             )
         else:
             o = self._call_decode(
                 reshape_q, k_cache_reshaped, self.forward_metadata.block_kv_indices[:bs],
                 (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                 layer.scaling,
             )
@@ -489,3 +494,95 @@ class DCUMLABackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+
+class DCUMLAMultiStepDraftBackend:
+    """
+    Wrap multiple flashmla attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        if topk > 1:
+            raise ValueError(
+                "Currently FlashMLA only supports topk=1 for speculative decoding"
+            )
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends.append(
+                DCUMLABackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                    kv_last_page_len_buf=None,
+                )
+            )
+
+    def common_template(
+        self,
+        forward_batch: ForwardBatch,
+        call_fn: Callable,
+    ):
+        assert forward_batch.spec_info is not None
+
+        for i in range(self.speculative_num_steps - 1):
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch: ForwardBatch, bs: int):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
+
+        self.common_template(forward_batch, call_fn)
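
The new DCUMLAMultiStepDraftBackend above follows the same multi-step pattern as the existing FlashMLA draft backend: it builds one single-step DCUMLABackend per draft step (all but the last), and each public method packages its work into a call_fn that common_template applies to every per-step backend in order. A stripped-down sketch of that dispatch pattern, with plain callables standing in for the real backends (all names here are illustrative only):

    from typing import Callable, List

    class MultiStepSketch:
        def __init__(self, speculative_num_steps: int):
            # One per-step handler for every draft step except the last.
            self.steps: List[int] = list(range(speculative_num_steps - 1))

        def common_template(self, call_fn: Callable[[int], None]) -> None:
            for i in self.steps:
                call_fn(i)

        def init_forward_metadata(self) -> None:
            def call_fn(i: int) -> None:
                # The real code delegates to self.attn_backends[i] here.
                print(f"init metadata for draft step {i}")

            self.common_template(call_fn)

    MultiStepSketch(speculative_num_steps=3).init_forward_metadata()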
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -695,6 +695,7 @@ class FlashAttentionBackend(AttentionBackend):
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
         # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+        data_dtype = q.dtype
         if (
             self.kv_cache_dtype_str != "auto"
             and layer.head_dim <= 256
@@ -828,7 +829,9 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.attn_attend_prefix_cache is not None
             and not forward_batch.forward_mode.is_target_verify()
             and not forward_batch.forward_mode.is_draft_extend()
         ):
+            k_descale = k_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
+            v_descale = v_descale if layer.k_scale is not None else torch.tensor([1.0], dtype=torch.float32, device=q.device)
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:
                 assert not get_global_server_args().disable_chunked_prefix_cache
@@ -842,9 +845,9 @@ class FlashAttentionBackend(AttentionBackend):
                 assert forward_batch.mha_return_lse
                 output = flash_attn_varlen_func(
-                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
+                    k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
+                    v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,
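
With an fp8 KV cache, k and v coming out of the pool are quantized, so the chunked-prefix MHA call now multiplies them by the per-layer descale factors and casts back to data_dtype (the query dtype captured before any fp8 handling) instead of using a bare .to(q.dtype) cast. A small self-contained sketch of why the descale matters (the shapes, scale value, and float32 round-trip are illustrative, and it needs a PyTorch build with float8 dtypes):

    import torch

    data_dtype = torch.bfloat16                  # dtype of q before any fp8 conversion
    k_descale = torch.tensor([4.0], dtype=torch.float32)

    # Pretend the cache stored k divided by the scale, rounded to fp8_e4m3.
    k_original = torch.randn(2, 8, 64, dtype=data_dtype) * 4.0
    k_fp8 = (k_original.float() / k_descale).to(torch.float8_e4m3fn)

    k_bare = k_fp8.to(data_dtype)                                      # ignores the scale: magnitudes stay shrunk
    k_dequant = (k_fp8.to(torch.float32) * k_descale).to(data_dtype)   # recovers roughly the original values

    print((k_original.float() - k_bare.float()).abs().mean())
    print((k_original.float() - k_dequant.float()).abs().mean())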
@@ -855,11 +858,10 @@ class FlashAttentionBackend(AttentionBackend):
                     **kwargs,
                 )
             else:
-                # MHA for extend part of sequence without attending prefix kv cache
                 output = flash_attn_varlen_func(
-                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim).to(data_dtype),
+                    k=(k.view(-1, layer.tp_k_head_num, layer.head_dim) * k_descale).to(data_dtype),
+                    v=(v.view(-1, layer.tp_k_head_num, layer.v_head_dim) * v_descale).to(data_dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=metadata.cu_seqlens_q,
                     max_seqlen_q=metadata.max_seq_len_q,
python/sglang/srt/models/deepseek_v2.py
@@ -2296,7 +2296,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             # Fetch latent cache from memory pool with precomputed chunked kv indices
             latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                 self.attn_mha.layer_id
-            )
+            ).to(q.dtype)
             latent_cache = (
                 latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
                 .contiguous()
python/sglang/srt/speculative/draft_utils.py
@@ -46,6 +46,7 @@ class DraftBackendFactory:
                 else self._create_triton_decode_backend
             ),
             "flashmla": self._create_flashmla_decode_backend,
+            "dcu_mla": self._create_dcumla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
             "nsa": self._create_nsa_decode_backend,
@@ -69,6 +70,7 @@ class DraftBackendFactory:
                 else self._create_triton_prefill_backend
             ),
             "flashmla": self._create_flashmla_prefill_backend,
+            "dcu_mla": self._create_dcumla_prefill_backend,
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
             "nsa": self._create_nsa_prefill_backend,
@@ -149,6 +151,15 @@ class DraftBackendFactory:
         return FlashMLAMultiStepDraftBackend(
             self.draft_model_runner, self.topk, self.speculative_num_steps
         )

+    def _create_dcumla_decode_backend(self):
+        from sglang.srt.layers.attention.dcu_mla_backend import (
+            DCUMLAMultiStepDraftBackend,
+        )
+
+        return DCUMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
     def _create_trtllm_mha_decode_backend(self):
         from sglang.srt.layers.attention.trtllm_mha_backend import (
@@ -224,3 +235,9 @@ class DraftBackendFactory:
             "flashmla prefill backend is not yet supported for draft extend."
         )
         return None
+
+    def _create_dcumla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
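
Taken together, the draft_utils.py changes register the DCU MLA backend with the draft-backend factory: a "dcu_mla" entry in both dispatch tables, a decode creator that lazily imports DCUMLAMultiStepDraftBackend, and a prefill creator that, like the flashmla one, only logs a warning and returns None because draft extend is not supported yet. A hedged sketch of how such a string-keyed table resolves a backend name (the trimmed-down dict and the backend_name variable are illustrative; the real factory holds bound methods and more entries):

    from typing import Callable, Dict, Optional

    def create_dcumla_decode_backend() -> object:
        # Stand-in for DraftBackendFactory._create_dcumla_decode_backend.
        return object()

    decode_backends: Dict[str, Callable[[], Optional[object]]] = {
        "flashmla": lambda: object(),
        "dcu_mla": create_dcumla_decode_backend,
    }

    backend_name = "dcu_mla"   # would come from the server's attention-backend setting
    creator = decode_backends.get(backend_name)
    draft_decode_backend = creator() if creator is not None else None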