Unverified Commit 918e3d4c authored by kk's avatar kk Committed by GitHub
Browse files

Fix accuracy drop of dsv3 run in dp enablement (#8677)


Co-authored-by: default avatarwunhuang <wunhuang@amd.com>
parent e9697374
......@@ -18,7 +18,10 @@ import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
is_dp_attention_enabled,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
if TYPE_CHECKING:
......@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.enable_dp_attention = is_dp_attention_enabled()
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init auxiliary variables for triton attention backend."""
......@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
if self.use_mla:
self.mla_indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.extend_prefix_lens,
sum(forward_batch.extend_prefix_lens_cpu),
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
forward_batch.extend_seq_lens,
max(forward_batch.extend_seq_lens_cpu),
forward_batch.seq_lens_cpu.max().item(),
forward_batch.extend_seq_lens.max().item(),
forward_batch.seq_lens.max().item(),
spec_info=None,
)
self.mla_indices_updater_prefill.kv_indptr += (
self.mla_indices_updater_prefill.qo_indptr
)
kv_indices = self.mla_indices_updater_prefill.kv_indices
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.kv_last_page_len[:bs],
self.mla_indices_updater_prefill.max_q_len,
......@@ -614,6 +619,7 @@ class AiterAttnBackend(AttentionBackend):
assert len(k.shape) == 3
assert len(v.shape) == 3
if forward_batch.forward_mode.is_extend():
if kv_indices.shape[0] == 0:
o = flash_attn_varlen_func(
q,
......@@ -654,13 +660,9 @@ class AiterAttnBackend(AttentionBackend):
forward_batch.extend_prefix_lens.shape
== forward_batch.extend_seq_lens.shape
)
k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
k = k_prefix
v = v_prefix
o = flash_attn_varlen_func(
q,
......@@ -674,6 +676,29 @@ class AiterAttnBackend(AttentionBackend):
causal=True,
)
return o
else:
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty(
(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
)
else:
o = torch.empty_like(q)
mla_prefill_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
qo_indptr,
kv_indptr,
kv_indices,
self.forward_metadata.kv_last_page_len,
self.forward_metadata.max_q_len,
layer.scaling,
layer.logit_cap,
)
K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
return o
elif forward_batch.forward_mode.is_target_verify():
o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
mla_decode_fwd(
......
......@@ -1085,6 +1085,12 @@ class DeepseekV2AttentionMLA(nn.Module):
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
if is_dp_attention_enabled():
if sum(forward_batch.extend_prefix_lens_cpu) == 0:
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
else:
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment