change / sglang · Commits · 918e3d4c

Unverified commit 918e3d4c, authored Sep 05, 2025 by kk, committed by GitHub Sep 04, 2025
Parent: e9697374

Fix accuracy drop of dsv3 run in dp enablement (#8677)

Co-authored-by: wunhuang <wunhuang@amd.com>
Showing 2 changed files with 100 additions and 69 deletions:

  python/sglang/srt/layers/attention/aiter_backend.py  +93 −68
  python/sglang/srt/models/deepseek_v2.py               +7 −1
python/sglang/srt/layers/attention/aiter_backend.py

```diff
@@ -18,7 +18,10 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
-from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
```
```diff
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        self.enable_dp_attention = is_dp_attention_enabled()
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
```
```diff
@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
         if self.use_mla:
             self.mla_indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
-                forward_batch.extend_prefix_lens,
-                sum(forward_batch.extend_prefix_lens_cpu),
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
                 forward_batch.extend_seq_lens,
-                max(forward_batch.extend_seq_lens_cpu),
-                forward_batch.seq_lens_cpu.max().item(),
+                forward_batch.extend_seq_lens.max().item(),
+                forward_batch.seq_lens.max().item(),
                 spec_info=None,
             )
-            self.mla_indices_updater_prefill.kv_indptr += (
-                self.mla_indices_updater_prefill.qo_indptr
-            )
+
+            kv_indices = self.mla_indices_updater_prefill.kv_indices
+
             self.forward_metadata = ForwardMetadata(
                 self.mla_indices_updater_prefill.kv_indptr,
-                self.mla_indices_updater_prefill.kv_indices,
+                kv_indices,
                 self.mla_indices_updater_prefill.qo_indptr,
                 self.kv_last_page_len[:bs],
                 self.mla_indices_updater_prefill.max_q_len,
```
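Building the prefill indices from the full `seq_lens` instead of the cached-prefix lengths is what makes the old `kv_indptr += qo_indptr` patch-up unnecessary: a CSR-style indptr over the full lengths already equals the prefix indptr plus the query indptr. A minimal sketch of that identity in plain PyTorch (the `indptr` helper and toy lengths are illustrative, not the sglang updater):

```python
import torch

# Per-request lengths for a toy batch of two requests.
extend_prefix_lens = torch.tensor([2, 3])  # tokens already in the KV cache
extend_seq_lens = torch.tensor([4, 5])     # new tokens in this extend step
seq_lens = extend_prefix_lens + extend_seq_lens  # full lengths: [6, 8]


def indptr(lens: torch.Tensor) -> torch.Tensor:
    """Exclusive prefix sum, the usual CSR-style indptr layout."""
    out = torch.zeros(len(lens) + 1, dtype=torch.int64)
    out[1:] = torch.cumsum(lens, dim=0)
    return out


kv_indptr_prefix = indptr(extend_prefix_lens)  # old: built from prefix lens only
qo_indptr = indptr(extend_seq_lens)            # query indptr over the new tokens
kv_indptr_full = indptr(seq_lens)              # new: built from full lengths

# The old code patched the prefix indptr with the query indptr; building from
# the full sequence lengths gives the same offsets directly.
assert torch.equal(kv_indptr_prefix + qo_indptr, kv_indptr_full)
print(kv_indptr_full.tolist())  # [0, 6, 14]
```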
```diff
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
             assert len(k.shape) == 3
             assert len(v.shape) == 3
 
-            if kv_indices.shape[0] == 0:
-                o = flash_attn_varlen_func(
-                    q,
-                    k,
-                    v,
-                    qo_indptr,
-                    qo_indptr,
-                    max_q_len,
-                    max_q_len,
-                    softmax_scale=layer.scaling,
-                    causal=True,
-                )
-                return o
-            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
-                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
-                kvc, k_pe = torch.split(
-                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
-                )
-                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
-
-                kvprefix = kvprefix.view(
-                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
-                )
-                k_prefix, v_prefix = torch.split(
-                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
-                )
-                k_prefix = torch.cat(
-                    [
-                        k_prefix,
-                        torch.broadcast_to(
-                            k_pe,
-                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
-                        ),
-                    ],
-                    dim=-1,
-                )
-                assert (
-                    forward_batch.extend_prefix_lens.shape
-                    == forward_batch.extend_seq_lens.shape
-                )
-                k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
-                k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
-                assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
-                k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
-                v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
-                v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
-                v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
-                o = flash_attn_varlen_func(
-                    q,
-                    k,
-                    v,
-                    qo_indptr,
-                    kv_indptr,
-                    max_q_len,
-                    max_kv_len,
-                    softmax_scale=layer.scaling,
-                    causal=True,
-                )
-                return o
+            if forward_batch.forward_mode.is_extend():
+                if kv_indices.shape[0] == 0:
+                    o = flash_attn_varlen_func(
+                        q,
+                        k,
+                        v,
+                        qo_indptr,
+                        qo_indptr,
+                        max_q_len,
+                        max_q_len,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                    )
+                    return o
+                elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                    K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                    kvc, k_pe = torch.split(
+                        K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                    )
+                    kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+
+                    kvprefix = kvprefix.view(
+                        -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                    )
+                    k_prefix, v_prefix = torch.split(
+                        kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                    )
+                    k_prefix = torch.cat(
+                        [
+                            k_prefix,
+                            torch.broadcast_to(
+                                k_pe,
+                                (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                            ),
+                        ],
+                        dim=-1,
+                    )
+                    assert (
+                        forward_batch.extend_prefix_lens.shape
+                        == forward_batch.extend_seq_lens.shape
+                    )
+
+                    k = k_prefix
+                    v = v_prefix
+
+                    o = flash_attn_varlen_func(
+                        q,
+                        k,
+                        v,
+                        qo_indptr,
+                        kv_indptr,
+                        max_q_len,
+                        max_kv_len,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                    )
+                    return o
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        o = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        o = torch.empty_like(q)
+
+                    mla_prefill_fwd(
+                        q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                        K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                        o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                        qo_indptr,
+                        kv_indptr,
+                        kv_indices,
+                        self.forward_metadata.kv_last_page_len,
+                        self.forward_metadata.max_q_len,
+                        layer.scaling,
+                        layer.logit_cap,
+                    )
+                    K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                    return o
             elif forward_batch.forward_mode.is_target_verify():
                 o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
                 mla_decode_fwd(
```
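The deleted split/zip block existed because the old indices covered only the cached prefix, so the K/V gathered from the cache had to be interleaved per request with the freshly computed extend K/V before calling `flash_attn_varlen_func`; with indices now spanning whole sequences, the gathered `k_prefix`/`v_prefix` are already complete and are used directly. A small sketch of what the removed interleaving produced, using toy tensors rather than the real MLA layout:

```python
import torch

# Toy per-request lengths: two requests with cached prefixes and new tokens.
extend_prefix_lens_cpu = [2, 3]
extend_seq_lens_cpu = [4, 5]

# Pretend these are K tensors of shape [total_tokens, heads, head_dim].
k_prefix_all = torch.arange(sum(extend_prefix_lens_cpu)).view(-1, 1, 1).float()
k_extend_all = torch.arange(100, 100 + sum(extend_seq_lens_cpu)).view(-1, 1, 1).float()

# The removed code: split both tensors per request, then interleave so each
# request's rows become [prefix_i, extend_i] in packed (varlen) order.
k_prefix = torch.split(k_prefix_all, extend_prefix_lens_cpu)
k_extend = torch.split(k_extend_all, extend_seq_lens_cpu)
k = torch.cat([x for pair in zip(k_prefix, k_extend) for x in pair])

# Packed order: 2 prefix rows, 4 extend rows, 3 prefix rows, 5 extend rows.
print(k.flatten().tolist())
# [0.0, 1.0, 100.0, 101.0, 102.0, 103.0, 2.0, 3.0, 4.0, 104.0, ...]
```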
python/sglang/srt/models/deepseek_v2.py

```diff
@@ -1085,7 +1085,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-                return AttnForwardMethod.MHA
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:
```
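The deepseek_v2.py half of the fix changes the dispatch rule: with DP attention enabled, the ragged MHA path is kept only for batches that carry no cached prefix at all, and any prefix-carrying extend falls back to MLA. A standalone sketch of that decision, with `AttnForwardMethod`, the flag, and the batch reduced to plain Python stand-ins (hypothetical names, not the sglang classes):

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List


class AttnForwardMethod(Enum):
    MHA = auto()
    MLA = auto()


@dataclass
class FakeBatch:
    extend_prefix_lens_cpu: List[int]


def choose_forward_method(batch: FakeBatch, dp_attention_enabled: bool) -> AttnForwardMethod:
    # Mirrors the new branch: under DP attention, MHA is used only when no
    # request in the batch has a cached prefix; otherwise fall back to MLA.
    if dp_attention_enabled:
        if sum(batch.extend_prefix_lens_cpu) == 0:
            return AttnForwardMethod.MHA
        return AttnForwardMethod.MLA
    return AttnForwardMethod.MHA


print(choose_forward_method(FakeBatch([0, 0]), dp_attention_enabled=True))    # MHA
print(choose_forward_method(FakeBatch([0, 128]), dp_attention_enabled=True))  # MLA
print(choose_forward_method(FakeBatch([0, 128]), dp_attention_enabled=False)) # MHA
```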