Unverified Commit d6e1d28c authored by fzyzcjy, committed by GitHub

Refactor DeepSeek attention dispatching (#6476)

parent 7c347259
@@ -127,6 +127,9 @@ class AttnForwardMethod(IntEnum):
# This method can avoid OOM when prefix lengths are long.
MHA_CHUNKED_KV = auto()
# Use MLA but with fused RoPE
MLA_FUSED_ROPE = auto()
class DeepseekV2MLP(nn.Module):
def __init__(
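Note (sketch, not part of the commit): the hunk above only appends one member to the AttnForwardMethod enum. A minimal standalone version is shown below for context; the member names are the ones referenced in this diff, while the ordering and the glosses on MHA / MLA are illustrative and the real enum may contain further members.

from enum import IntEnum, auto

class AttnForwardMethod(IntEnum):
    # Plain multi-head attention (gloss added for this sketch)
    MHA = auto()
    # Weight-absorbed multi-latent attention (gloss added for this sketch)
    MLA = auto()
    # Use multi-head attention, but chunk the KV cache.
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()
    # Use MLA but with fused RoPE
    MLA_FUSED_ROPE = auto()

# The new member is a distinct dispatch target:
assert AttnForwardMethod.MLA_FUSED_ROPE is not AttnForwardMethod.MLA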
@@ -609,6 +612,18 @@ class DeepseekV2AttentionMLA(nn.Module):
def dispatch_attn_forward_method(
self, forward_batch: ForwardBatch
) -> AttnForwardMethod:
def _dispatch_mla_subtype():
if _is_hip:
if (
self.rocm_fused_decode_mla
and forward_batch.forward_mode.is_decode()
):
return AttnForwardMethod.MLA_FUSED_ROPE
else:
return AttnForwardMethod.MLA
else:
return AttnForwardMethod.MLA
if self.attention_backend == "flashinfer":
# Flashinfer MLA: Do not absorb when enabling ragged prefill
if (
@@ -620,7 +635,7 @@ class DeepseekV2AttentionMLA(nn.Module):
):
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
return _dispatch_mla_subtype()
elif self.attention_backend == "fa3":
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
if forward_batch.extend_prefix_lens_cpu is not None:
@@ -637,7 +652,7 @@ class DeepseekV2AttentionMLA(nn.Module):
):
return AttnForwardMethod.MHA_CHUNKED_KV
else:
return AttnForwardMethod.MLA
return _dispatch_mla_subtype()
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
if (
@@ -648,7 +663,7 @@ class DeepseekV2AttentionMLA(nn.Module):
):
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
return _dispatch_mla_subtype()
def forward(
self,
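Sketch (not part of the commit): every backend branch above now ends in the same nested helper, so the _is_hip / decode special case lives in one place. Below, the helper's decision logic is pulled out as a free function so it can be exercised directly; in the actual code it is a closure over self and forward_batch inside dispatch_attn_forward_method, and the parameter names here only mirror the attributes used in the diff.

from enum import IntEnum, auto

class AttnForwardMethod(IntEnum):
    MLA = auto()
    MLA_FUSED_ROPE = auto()

def dispatch_mla_subtype(is_hip: bool, rocm_fused_decode_mla: bool, is_decode: bool) -> AttnForwardMethod:
    # Only ROCm builds with the fused decode kernel enabled, and only for
    # decode batches, take the fused-RoPE path; everything else uses plain MLA.
    if is_hip and rocm_fused_decode_mla and is_decode:
        return AttnForwardMethod.MLA_FUSED_ROPE
    return AttnForwardMethod.MLA

# Behavior mirrors the branches in the hunk above:
assert dispatch_mla_subtype(True, True, True) is AttnForwardMethod.MLA_FUSED_ROPE
assert dispatch_mla_subtype(True, True, False) is AttnForwardMethod.MLA
assert dispatch_mla_subtype(False, True, True) is AttnForwardMethod.MLA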
@@ -671,23 +686,16 @@ class DeepseekV2AttentionMLA(nn.Module):
return self.forward_normal_chunked_kv(
positions, hidden_states, forward_batch
)
elif attn_forward_method == AttnForwardMethod.MLA:
return self.forward_absorb(
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
return self.forward_absorb_fused_mla_rope(
positions, hidden_states, forward_batch
)
else:
if _is_hip:
if (
self.rocm_fused_decode_mla
and forward_batch.forward_mode.is_decode()
):
return self.forward_absorb_fused_mla_rope(
positions, hidden_states, forward_batch
)
else:
return self.forward_absorb(
positions, hidden_states, forward_batch, zero_allocator
)
else:
return self.forward_absorb(
positions, hidden_states, forward_batch, zero_allocator
)
raise NotImplementedError
def forward_normal(
self,
......
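Sketch (not part of the commit): the removed lines above repeated the _is_hip and decode checks inside forward(); after this change, forward() only maps the dispatched method to a handler, and the platform logic stays in _dispatch_mla_subtype. A self-contained illustration of that control flow is below, with stubbed handlers and only the MLA-related branches shown; the real methods take the positions / hidden_states / forward_batch / zero_allocator arguments from the diff and launch the MLA kernels.

from enum import IntEnum, auto

class AttnForwardMethod(IntEnum):
    MLA = auto()
    MLA_FUSED_ROPE = auto()

class AttentionSketch:
    # Stand-in for DeepseekV2AttentionMLA; both handlers are stubs.
    def dispatch_attn_forward_method(self, forward_batch) -> AttnForwardMethod:
        return AttnForwardMethod.MLA  # backend-aware dispatch elided

    def forward_absorb(self, positions, hidden_states, forward_batch, zero_allocator):
        return "absorbed MLA"

    def forward_absorb_fused_mla_rope(self, positions, hidden_states, forward_batch):
        return "absorbed MLA with fused RoPE"

    def forward(self, positions, hidden_states, forward_batch, zero_allocator=None):
        method = self.dispatch_attn_forward_method(forward_batch)
        if method == AttnForwardMethod.MLA:
            return self.forward_absorb(positions, hidden_states, forward_batch, zero_allocator)
        if method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope(positions, hidden_states, forward_batch)
        raise NotImplementedError

assert AttentionSketch().forward(None, None, None) == "absorbed MLA"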