Unverified Commit b17e67df authored by Yineng Zhang, committed by GitHub

[Auto Sync] Update deepseek_v2.py (20250920) (#10683)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 8ecef73f
@@ -177,6 +177,20 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()
logger = logging.getLogger(__name__)

FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
    "fa3",
    "flashinfer",
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
]


def add_forward_absorb_core_attention_backend(backend_name):
    if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
        FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
        logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")


class AttnForwardMethod(IntEnum):
    # Use multi-head attention
@@ -196,6 +210,134 @@ class AttnForwardMethod(IntEnum):
    MLA_FUSED_ROPE_CPU = auto()


def _dispatch_mla_subtype(attn, forward_batch):
    if _is_hip:
        if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
            return AttnForwardMethod.MLA_FUSED_ROPE
        else:
            return AttnForwardMethod.MLA
    else:
        if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
            return AttnForwardMethod.MLA_FUSED_ROPE_CPU
        else:
            return AttnForwardMethod.MLA


class BackendRegistry:
    _handlers = {}

    @classmethod
    def register(cls, backend_name, handler_func):
        cls._handlers[backend_name] = handler_func

    @classmethod
    def get_handler(cls, backend_name):
        # Backend names without a registered handler fall back to the "triton" handler.
        return cls._handlers.get(backend_name, cls._handlers.get("triton"))
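

# NOTE: illustrative sketch, not part of this commit. It shows the lookup
# behavior of BackendRegistry, assuming the handler registrations made at the
# bottom of this file have already run (i.e. the module has been imported).
def _demo_registry_fallback():
    assert BackendRegistry.get_handler("fa3") is handle_fa3
    # An unknown backend name resolves to the "triton" handler.
    assert BackendRegistry.get_handler("nonexistent_backend") is handle_triton

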
def handle_ascend(attn, forward_batch):
    if (
        forward_batch.forward_mode.is_extend()
        and not forward_batch.forward_mode.is_target_verify()
        and not forward_batch.forward_mode.is_draft_extend()
    ):
        return AttnForwardMethod.MHA
    else:
        return AttnForwardMethod.MLA


def _get_sum_extend_prefix_lens(forward_batch):
    return (
        sum(forward_batch.extend_prefix_lens_cpu)
        if forward_batch.extend_prefix_lens_cpu is not None
        else 0
    )


def _is_extend_without_speculative(forward_batch):
    return (
        forward_batch.forward_mode.is_extend()
        and not forward_batch.forward_mode.is_target_verify()
        and not forward_batch.forward_mode.is_draft_extend()
    )


def _handle_backend(attn, forward_batch, backend_name):
    # Use MHA with chunked KV cache when prefilling on long sequences.
    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
    # Flashinfer MLA: do not absorb when ragged prefill is enabled.
    disable_ragged = (
        backend_name in ["flashinfer", "flashmla"]
    ) and attn.flashinfer_mla_disable_ragged
    if (
        not disable_ragged
        and _is_extend_without_speculative(forward_batch)
        and (
            (
                sum_extend_prefix_lens >= attn.chunked_prefix_cache_threshold
                and not attn.disable_chunked_prefix_cache
            )
            or sum_extend_prefix_lens == 0
        )
    ):
        return AttnForwardMethod.MHA_CHUNKED_KV
    else:
        return _dispatch_mla_subtype(attn, forward_batch)
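

# NOTE: illustrative sketch, not part of this commit. It restates the branch in
# _handle_backend over plain values to make the MHA_CHUNKED_KV vs. MLA choice
# explicit; the helper name and the example threshold value are hypothetical.
def _would_use_mha_chunked_kv(
    is_extend_without_speculative: bool,
    sum_extend_prefix_lens: int,
    chunked_prefix_cache_threshold: int,
    disable_chunked_prefix_cache: bool,
    disable_ragged: bool,
) -> bool:
    return (
        not disable_ragged
        and is_extend_without_speculative
        and (
            (
                sum_extend_prefix_lens >= chunked_prefix_cache_threshold
                and not disable_chunked_prefix_cache
            )
            or sum_extend_prefix_lens == 0
        )
    )


# Example: a fresh prefill with no cached prefix uses chunked-KV MHA:
#   _would_use_mha_chunked_kv(True, 0, 8192, False, False) -> True
# Example: a short cached prefix below the (example) threshold stays on the MLA path:
#   _would_use_mha_chunked_kv(True, 1024, 8192, False, False) -> False

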
def handle_flashinfer(attn, forward_batch):
    return _handle_backend(attn, forward_batch, "flashinfer")


def handle_fa3(attn, forward_batch):
    return _handle_backend(attn, forward_batch, "fa3")


def handle_flashmla(attn, forward_batch):
    return _handle_backend(attn, forward_batch, "flashmla")


def handle_cutlass_mla(attn, forward_batch):
    return _handle_backend(attn, forward_batch, "cutlass_mla")


def handle_fa4(attn, forward_batch):
    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
    return AttnForwardMethod.MHA_CHUNKED_KV


def handle_trtllm_mla(attn, forward_batch):
    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
    if _is_extend_without_speculative(forward_batch) and (
        not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
    ):
        return AttnForwardMethod.MHA_CHUNKED_KV
    else:
        return _dispatch_mla_subtype(attn, forward_batch)


def handle_aiter(attn, forward_batch):
    if _is_extend_without_speculative(forward_batch):
        if is_dp_attention_enabled():
            if sum(forward_batch.extend_prefix_lens_cpu) == 0:
                return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA
        else:
            return AttnForwardMethod.MHA
    else:
        return AttnForwardMethod.MLA


def handle_triton(attn, forward_batch):
    # Triton: use normal computation for prefill and weight absorption for extend/decode.
    if (
        _is_extend_without_speculative(forward_batch)
        and sum(forward_batch.extend_prefix_lens_cpu) == 0
    ):
        return AttnForwardMethod.MHA
    else:
        return _dispatch_mla_subtype(attn, forward_batch)


class DeepseekV2MLP(nn.Module):
    def __init__(
        self,
@@ -1039,23 +1181,6 @@ class DeepseekV2AttentionMLA(nn.Module):
    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        def _dispatch_mla_subtype():
            if _is_hip:
                if (
                    self.rocm_fused_decode_mla
                    and forward_batch.forward_mode.is_decode()
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE
                else:
                    return AttnForwardMethod.MLA
            else:
                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
                    self
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
                else:
                    return AttnForwardMethod.MLA

        # Determine attention backend used by current forward batch
        if forward_batch.forward_mode.is_decode_or_idle():
            attention_backend = global_server_args_dict["decode_attention_backend"]
@@ -1072,94 +1197,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            attention_backend = global_server_args_dict["prefill_attention_backend"]
        self.current_attention_backend = attention_backend
        if attention_backend == "ascend":
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA
        elif (
            attention_backend == "flashinfer"
            or attention_backend == "fa3"
            or attention_backend == "flashmla"
            or attention_backend == "cutlass_mla"
        ):
            # Use MHA with chunked KV cache when prefilling on long sequences.
            sum_extend_prefix_lens = (
                sum(forward_batch.extend_prefix_lens_cpu)
                if forward_batch.extend_prefix_lens_cpu is not None
                else 0
            )
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            disable_ragged = (
                attention_backend == "flashinfer" or attention_backend == "flashmla"
            ) and self.flashinfer_mla_disable_ragged
            if (
                not disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and (
                    (
                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                        and not self.disable_chunked_prefix_cache
                    )
                    or sum_extend_prefix_lens == 0
                )
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            else:
                return _dispatch_mla_subtype()
        elif attention_backend == "fa4":
            # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
            return AttnForwardMethod.MHA_CHUNKED_KV
        elif attention_backend == "trtllm_mla":
            sum_extend_prefix_lens = (
                sum(forward_batch.extend_prefix_lens_cpu)
                if forward_batch.extend_prefix_lens_cpu is not None
                else 0
            )
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and (
                    not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
                )
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            else:
                return _dispatch_mla_subtype()
        elif attention_backend == "aiter":
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                if is_dp_attention_enabled():
                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
                        return AttnForwardMethod.MHA
                    else:
                        return AttnForwardMethod.MLA
                else:
                    return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            ):
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()

        handler = BackendRegistry.get_handler(attention_backend)
        return handler(self, forward_batch)

    def op_prepare(self, state):
        state.attn_intermediate_state = self.forward_prepare(
@@ -1456,13 +1495,7 @@ class DeepseekV2AttentionMLA(nn.Module):
    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
    ):
        if (
            self.current_attention_backend == "fa3"
            or self.current_attention_backend == "flashinfer"
            or self.current_attention_backend == "cutlass_mla"
            or self.current_attention_backend == "trtllm_mla"
            or self.current_attention_backend == "ascend"
        ):
        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
            extra_args = {}
            if self._fuse_rope_for_trtllm_mla(forward_batch):
                extra_args = {
@@ -3016,6 +3049,17 @@ class DeepseekV2ForCausalLM(nn.Module):
)
BackendRegistry.register("ascend", handle_ascend)
BackendRegistry.register("flashinfer", handle_flashinfer)
BackendRegistry.register("fa3", handle_fa3)
BackendRegistry.register("flashmla", handle_flashmla)
BackendRegistry.register("cutlass_mla", handle_cutlass_mla)
BackendRegistry.register("fa4", handle_fa4)
BackendRegistry.register("trtllm_mla", handle_trtllm_mla)
BackendRegistry.register("aiter", handle_aiter)
BackendRegistry.register("triton", handle_triton)
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    pass