Unverified commit 918e3d4c authored by kk, committed by GitHub

Fix accuracy drop of dsv3 run in dp enablement (#8677)


Co-authored-by: wunhuang <wunhuang@amd.com>
parent e9697374
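
The core of the change is the prefill dispatch in DeepseekV2AttentionMLA: with DP attention enabled, an extend batch keeps the MHA path only when it carries no cached prefix (the sum of extend_prefix_lens_cpu is zero); otherwise it falls back to MLA. Without DP attention the previous behavior (always MHA on extend) is unchanged. The sketch below restates that rule in isolation; AttnForwardMethod, is_dp_attention_enabled, and extend_prefix_lens_cpu come from the diff, while the standalone helper name and the simplified batch stand-in are illustrative only, not the actual sglang code.

```python
# Minimal sketch of the new dispatch rule; names marked as stand-ins are not real sglang APIs.
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List


class AttnForwardMethod(Enum):
    MHA = auto()
    MLA = auto()


@dataclass
class FakeForwardBatch:
    # Per-request lengths of the KV-cache prefix already stored for this batch
    # (stand-in for ForwardBatch.extend_prefix_lens_cpu).
    extend_prefix_lens_cpu: List[int] = field(default_factory=list)


def choose_prefill_attn_method(
    forward_batch: FakeForwardBatch, dp_attention_enabled: bool
) -> AttnForwardMethod:
    """Mirror of the patched extend-batch branch in DeepseekV2AttentionMLA."""
    if dp_attention_enabled:
        # Under DP attention, only prefix-free prefill keeps the MHA path;
        # any batch with a cached prefix is routed to MLA.
        if sum(forward_batch.extend_prefix_lens_cpu) == 0:
            return AttnForwardMethod.MHA
        return AttnForwardMethod.MLA
    # Without DP attention, behavior is unchanged: MHA on extend.
    return AttnForwardMethod.MHA


if __name__ == "__main__":
    no_prefix = FakeForwardBatch(extend_prefix_lens_cpu=[0, 0])
    with_prefix = FakeForwardBatch(extend_prefix_lens_cpu=[0, 128])
    assert choose_prefill_attn_method(no_prefix, True) is AttnForwardMethod.MHA
    assert choose_prefill_attn_method(with_prefix, True) is AttnForwardMethod.MLA
    assert choose_prefill_attn_method(with_prefix, False) is AttnForwardMethod.MHA
```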
@@ -18,7 +18,10 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
-from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        self.enable_dp_attention = is_dp_attention_enabled()
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    forward_batch.extend_prefix_lens,
-                    sum(forward_batch.extend_prefix_lens_cpu),
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
                     forward_batch.extend_seq_lens,
-                    max(forward_batch.extend_seq_lens_cpu),
-                    forward_batch.seq_lens_cpu.max().item(),
+                    forward_batch.extend_seq_lens.max().item(),
+                    forward_batch.seq_lens.max().item(),
                     spec_info=None,
                 )
-                self.mla_indices_updater_prefill.kv_indptr += (
-                    self.mla_indices_updater_prefill.qo_indptr
-                )
+                kv_indices = self.mla_indices_updater_prefill.kv_indices
+
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
+                    kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
                     self.kv_last_page_len[:bs],
                     self.mla_indices_updater_prefill.max_q_len,
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
             assert len(k.shape) == 3
             assert len(v.shape) == 3
 
-            if kv_indices.shape[0] == 0:
-                o = flash_attn_varlen_func(
-                    q,
-                    k,
-                    v,
-                    qo_indptr,
-                    qo_indptr,
-                    max_q_len,
-                    max_q_len,
-                    softmax_scale=layer.scaling,
-                    causal=True,
-                )
-                return o
-            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
-                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
-                kvc, k_pe = torch.split(
-                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
-                )
-                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
-
-                kvprefix = kvprefix.view(
-                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
-                )
-                k_prefix, v_prefix = torch.split(
-                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
-                )
-                k_prefix = torch.cat(
-                    [
-                        k_prefix,
-                        torch.broadcast_to(
-                            k_pe,
-                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
-                        ),
-                    ],
-                    dim=-1,
-                )
-                assert (
-                    forward_batch.extend_prefix_lens.shape
-                    == forward_batch.extend_seq_lens.shape
-                )
-                k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
-                k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
-                assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
-                k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
-                v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
-                v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
-                v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
-
-            o = flash_attn_varlen_func(
-                q,
-                k,
-                v,
-                qo_indptr,
-                kv_indptr,
-                max_q_len,
-                max_kv_len,
-                softmax_scale=layer.scaling,
-                causal=True,
-            )
-            return o
+            if forward_batch.forward_mode.is_extend():
+                if kv_indices.shape[0] == 0:
+                    o = flash_attn_varlen_func(
+                        q,
+                        k,
+                        v,
+                        qo_indptr,
+                        qo_indptr,
+                        max_q_len,
+                        max_q_len,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                    )
+                    return o
+                elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                    K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                    kvc, k_pe = torch.split(
+                        K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                    )
+                    kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+
+                    kvprefix = kvprefix.view(
+                        -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                    )
+                    k_prefix, v_prefix = torch.split(
+                        kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                    )
+                    k_prefix = torch.cat(
+                        [
+                            k_prefix,
+                            torch.broadcast_to(
+                                k_pe,
+                                (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                            ),
+                        ],
+                        dim=-1,
+                    )
+                    assert (
+                        forward_batch.extend_prefix_lens.shape
+                        == forward_batch.extend_seq_lens.shape
+                    )
+
+                    k = k_prefix
+                    v = v_prefix
+
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    kv_indptr,
+                    max_q_len,
+                    max_kv_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+
+            else:
+                if layer.qk_head_dim != layer.v_head_dim:
+                    o = q.new_empty(
+                        (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                    )
+                else:
+                    o = torch.empty_like(q)
+
+                mla_prefill_fwd(
+                    q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                return o
         elif forward_batch.forward_mode.is_target_verify():
             o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
             mla_decode_fwd(
......
@@ -1085,7 +1085,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-                return AttnForwardMethod.MHA
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:
......