change / sglang · Commits

Commit e36f865d
Authored Nov 07, 2025 by linhai1
Parent: 46da9556

    Fix Bug.
Showing 4 changed files with 83 additions and 11 deletions (+83 -11):

    python/sglang/srt/layers/attention/dcu_mla_backend.py           +37  -9
    python/sglang/srt/layers/attention/flashattention_interface.py  +42  -1
    python/sglang/srt/model_executor/model_runner.py                  +1  -0
    python/sglang/srt/models/deepseek_v2.py                           +3  -1
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -387,18 +387,30 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+                if k_rope is None:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

@@ -432,7 +444,9 @@ class DCUMLABackend(AttentionBackend):
         layer: "RadixAttention",
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope=None,
+        k_rope=None,
         sinks=None,
     ):
         if (
             forward_batch.forward_mode == ForwardMode.EXTEND

@@ -444,7 +458,7 @@ class DCUMLABackend(AttentionBackend):
             #     q, k, v, layer, forward_batch, save_kv_cache, sinks
             # )
             return self.flashattn_backend.forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, sinks
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
             )
         else:
             raise RuntimeError("skip prefill but use forward_extend")

@@ -453,7 +467,21 @@ class DCUMLABackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                # forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if k_rope is None:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
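The hunks above make the KV-cache write conditional on whether a decoupled k_rope tensor is supplied: with k_rope present, the backend writes the MLA cache via set_mla_kv_buffer; otherwise it keeps the plain set_kv_buffer path. The snippet below is a stripped-down, hypothetical sketch of that dispatch; MockKVPool only imitates the two method names visible in the diff and is not sglang's real token_to_kv_pool.

import torch

class MockKVPool:
    """Stand-in for forward_batch.token_to_kv_pool; records which write path ran."""

    def __init__(self):
        self.writes = []

    def set_kv_buffer(self, layer, cache_loc, k, v):
        self.writes.append(("kv", cache_loc.shape[0]))

    def set_mla_kv_buffer(self, layer, cache_loc, k_nope, k_rope):
        self.writes.append(("mla", cache_loc.shape[0]))

def save_cache(pool, layer, cache_loc, k, v, k_rope=None):
    # Mirrors the new branch introduced in forward_extend / the decode path above.
    if k_rope is None:
        pool.set_kv_buffer(layer, cache_loc, k, v)
    else:
        pool.set_mla_kv_buffer(layer, cache_loc, k, k_rope)

pool = MockKVPool()
loc = torch.arange(8)
k = torch.randn(8, 1, 512)    # compressed / "nope" part (shapes are illustrative)
rope = torch.randn(8, 1, 64)  # decoupled rope part
save_cache(pool, layer=None, cache_loc=loc, k=k, v=None, k_rope=rope)
save_cache(pool, layer=None, cache_loc=loc, k=k, v=torch.randn(8, 1, 512))
print(pool.writes)  # [('mla', 8), ('kv', 8)]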
python/sglang/srt/layers/attention/flashattention_interface.py

@@ -6,6 +6,8 @@ from typing import Optional, Union
 import torch
 
+MAX_FLASH_ATTN_KERNEL_HEADDIM = 256
+
 def flash_attn_with_kvcache(
     q,
     k_cache,

@@ -40,7 +42,46 @@ def flash_attn_with_kvcache(
     sinks=None,
     ver=3,
 ):
-    return flash_attn_with_kvcache_interface(
+    if cu_seqlens_q is not None and q.shape[0] != cu_seqlens_q.shape[0] * max_seqlen_q:
+        v_cache = v_cache.view(-1, v_cache.shape[-2], v_cache.shape[-1])
+        if v_cache.shape[-1] > MAX_FLASH_ATTN_KERNEL_HEADDIM:
+            out_1 = flash_attn_varlen_func_interface(
+                q=q,  # (total_q, num_heads, head_size_og)
+                k=k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1]),  # (total_k, num_heads_k, head_size_og)
+                v=v_cache[:, :, :MAX_FLASH_ATTN_KERNEL_HEADDIM],  # (total_k, num_heads_k, head_size_og)
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k_new if cu_seqlens_k_new is not None else None,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_q,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+            out_2 = flash_attn_varlen_func_interface(
+                q=q,  # (total_q, num_heads, head_size_og)
+                k=k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1]),  # (total_k, num_heads_k, head_size_og)
+                v=v_cache[:, :, MAX_FLASH_ATTN_KERNEL_HEADDIM:],  # (total_k, num_heads_k, head_size_og)
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k_new if cu_seqlens_k_new is not None else None,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_q,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+            return torch.cat([out_1, out_2], dim=-1)
+        else:
+            return flash_attn_varlen_func_interface(
+                q=q,  # (total_q, num_heads, head_size_og)
+                k=k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1]),  # (total_k, num_heads_k, head_size_og)
+                v=v_cache.view(-1, v_cache.shape[-2], v_cache.shape[-1]),  # (total_k, num_heads_k, head_size_og)
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k_new if cu_seqlens_k_new is not None else None,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_q,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+    else:
+        return flash_attn_with_kvcache_interface(
         q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
         k_cache=k_cache,
         v_cache=v_cache,
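The new branch splits v_cache at MAX_FLASH_ATTN_KERNEL_HEADDIM along the head dimension, runs the varlen kernel twice, and concatenates the two outputs. This is numerically safe because the softmax weights depend only on q and k, not on v, so attention over each slice of v can be computed independently. Below is a minimal plain-PyTorch check of that identity; it does not call sglang's kernels, and it omits causal masking, which does not change the argument.

import torch

def naive_attention(q, k, v, scale):
    # q, k: (seq_q/seq_k, heads, head_dim_qk); v: (seq_k, heads, head_dim_v)
    attn = torch.softmax(torch.einsum("qhd,khd->hqk", q, k) * scale, dim=-1)
    return torch.einsum("hqk,khd->qhd", attn, v)

torch.manual_seed(0)
q = torch.randn(4, 2, 64)
k = torch.randn(6, 2, 64)
v = torch.randn(6, 2, 512)   # head_dim_v larger than the 256 kernel limit
scale = 64 ** -0.5
split = 256                  # stands in for MAX_FLASH_ATTN_KERNEL_HEADDIM

full = naive_attention(q, k, v, scale)
halves = torch.cat(
    [naive_attention(q, k, v[..., :split], scale),
     naive_attention(q, k, v[..., split:], scale)],
    dim=-1,
)
assert torch.allclose(full, halves, atol=1e-5)  # same result as one full-width pass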
python/sglang/srt/model_executor/model_runner.py

@@ -178,6 +178,7 @@ CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
     "flashmla",
     "cutlass_mla",
     "trtllm_mla",
+    "dcu_mla",
 ]
python/sglang/srt/models/deepseek_v2.py

@@ -1662,7 +1662,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions,
         topk_indices,
     ):
-        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        # if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or \
+            (not forward_batch.forward_mode.is_decode() and self.current_attention_backend == 'dcu_mla'):
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
                 extra_args = {
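The rewritten guard now also takes this branch for the dcu_mla backend when the batch is not in decode mode (i.e. during prefill/extend). A small standalone restatement of the same boolean, where the backend set is an illustrative subset rather than the real FORWARD_ABSORB_CORE_ATTENTION_BACKENDS list:

# Illustrative subset only; not the real constant from the sglang source.
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = {"flashmla", "cutlass_mla", "trtllm_mla"}

def takes_branch(backend: str, is_decode: bool) -> bool:
    # Matches: backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or
    #          (not forward_batch.forward_mode.is_decode() and backend == 'dcu_mla')
    return backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS or (
        not is_decode and backend == "dcu_mla"
    )

assert takes_branch("dcu_mla", is_decode=False)     # prefill/extend on dcu_mla: now included
assert not takes_branch("dcu_mla", is_decode=True)  # decode on dcu_mla: unchanged
assert takes_branch("flashmla", is_decode=True)     # backends already in the set: unchanged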