Unverified Commit d6e1d28c authored by fzyzcjy, committed by GitHub

Refactor DeepSeek attention dispatching (#6476)

parent 7c347259
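
This commit adds an `MLA_FUSED_ROPE` member to `AttnForwardMethod`, folds the HIP-specific fused-RoPE decode check into a single `_dispatch_mla_subtype()` helper inside `dispatch_attn_forward_method()`, and makes `forward()` branch only on the dispatched enum value (raising `NotImplementedError` for anything unhandled). Below is a minimal, self-contained sketch of that dispatch pattern, not the module being patched: names such as `rocm_fused_decode_mla`, `forward_absorb`, and `forward_absorb_fused_mla_rope` are taken from the diff, while the simplified class, the `is_decode` flag, and the stub return values are assumptions for illustration.

```python
# Minimal sketch of the dispatch pattern (hypothetical stand-in classes, not
# the patched module). The decision of *which* MLA variant to run is made once
# in the dispatcher, and forward() only switches on the resulting enum value.
from enum import IntEnum, auto


class AttnForwardMethod(IntEnum):
    MLA = auto()             # absorbed multi-latent attention
    MLA_FUSED_ROPE = auto()  # MLA with fused RoPE (HIP decode path)


class MiniAttention:
    """Hypothetical stand-in for the attention module in the diff."""

    def __init__(self, is_hip: bool, rocm_fused_decode_mla: bool):
        self._is_hip = is_hip
        self.rocm_fused_decode_mla = rocm_fused_decode_mla

    def dispatch_mla_subtype(self, is_decode: bool) -> AttnForwardMethod:
        # Same condition as _dispatch_mla_subtype() in the diff: fused RoPE
        # only on HIP, only when the ROCm fused decode kernel is enabled,
        # and only for decode batches.
        if self._is_hip and self.rocm_fused_decode_mla and is_decode:
            return AttnForwardMethod.MLA_FUSED_ROPE
        return AttnForwardMethod.MLA

    def forward(self, is_decode: bool) -> str:
        # forward() no longer re-checks platform flags; it just switches on
        # the dispatched enum value, mirroring the new elif chain in the diff.
        method = self.dispatch_mla_subtype(is_decode)
        if method == AttnForwardMethod.MLA:
            return "forward_absorb"  # stub for self.forward_absorb(...)
        elif method == AttnForwardMethod.MLA_FUSED_ROPE:
            return "forward_absorb_fused_mla_rope"  # stub for the fused-RoPE path
        else:
            raise NotImplementedError


if __name__ == "__main__":
    attn = MiniAttention(is_hip=True, rocm_fused_decode_mla=True)
    assert attn.forward(is_decode=True) == "forward_absorb_fused_mla_rope"
    assert attn.forward(is_decode=False) == "forward_absorb"
```

Keeping the platform and backend checks inside the dispatcher is what lets `forward()` become the flat switch over `AttnForwardMethod` seen in the last hunk.
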
@@ -127,6 +127,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
+
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()
 
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -609,6 +612,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
+        def _dispatch_mla_subtype():
+            if _is_hip:
+                if (
+                    self.rocm_fused_decode_mla
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE
+                else:
+                    return AttnForwardMethod.MLA
+            else:
+                return AttnForwardMethod.MLA
+
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
@@ -620,7 +635,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
@@ -637,7 +652,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -648,7 +663,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
 
     def forward(
         self,
@@ -671,23 +686,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv(
                 positions, hidden_states, forward_batch
             )
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope(
+                positions, hidden_states, forward_batch
+            )
         else:
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return self.forward_absorb_fused_mla_rope(
-                        positions, hidden_states, forward_batch
-                    )
-                else:
-                    return self.forward_absorb(
-                        positions, hidden_states, forward_batch, zero_allocator
-                    )
-            else:
-                return self.forward_absorb(
-                    positions, hidden_states, forward_batch, zero_allocator
-                )
+            raise NotImplementedError
 
     def forward_normal(
         self,
...