"vscode:/vscode.git/clone" did not exist on "ef5136a745138896d080bf5bcac13377f7672b77"
Unverified commit b65db028, authored by fzyzcjy, committed by GitHub

Tiny cleanup deepseek_v2.py (#11163)

parent 948278f1
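
This commit makes two mechanical cleanups: `FusedMoE.should_fuse_routed_scaling_factor_in_topk` changes from a method into a boolean attribute computed once in `__init__`, and the attention dispatch helpers in deepseek_v2.py are renamed (`BackendRegistry` -> `AttentionBackendRegistry`, `handle_*` -> `handle_attention_*`). As a rough illustration of the first change, here is a minimal, self-contained sketch of the precompute-a-flag pattern; `QuantMethodA`, `QuantMethodB`, and `FusedMoESketch` are placeholder names for this example, not the real sglang classes.

# Minimal sketch of the method-to-attribute cleanup in this commit.
# QuantMethodA / QuantMethodB are placeholder stand-ins, not real sglang classes.


class QuantMethodA:
    pass


class QuantMethodB:
    use_cutlass_fused_experts_fp8 = True


class FusedMoESketch:
    def __init__(self, quant_method):
        self.quant_method = quant_method
        # After the change: the flag is evaluated once at construction time ...
        self.should_fuse_routed_scaling_factor_in_topk = isinstance(
            quant_method, QuantMethodA
        ) or (
            isinstance(quant_method, QuantMethodB)
            and quant_method.use_cutlass_fused_experts_fp8
        )

    # Before the change, callers invoked a method that re-ran the isinstance
    # checks on every call:
    # def should_fuse_routed_scaling_factor_in_topk(self):
    #     return isinstance(self.quant_method, QuantMethodA) or (...)


moe = FusedMoESketch(QuantMethodB())
# Call sites now read the attribute (no trailing parentheses):
print(moe.should_fuse_routed_scaling_factor_in_topk)  # True

The attribute form avoids re-evaluating the isinstance checks at every call site and lets callers drop the trailing parentheses, which is exactly what the call-site hunks below do.
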
@@ -234,6 +234,13 @@ class FusedMoE(torch.nn.Module):
         self.quant_method.create_moe_runner(self, self.moe_runner_config)
         self.dispatcher = StandardDispatcher()

+        self.should_fuse_routed_scaling_factor_in_topk = isinstance(
+            self.quant_method, ModelOptNvFp4FusedMoEMethod
+        ) or (
+            isinstance(self.quant_method, Fp8MoEMethod)
+            and self.quant_method.use_cutlass_fused_experts_fp8
+        )
+
     def _load_per_tensor_weight_scale(
         self,
         shard_id: str,
@@ -936,12 +943,6 @@ class FusedMoE(torch.nn.Module):
             for shard_id in ["w1", "w2", "w3"]
        ]

-    def should_fuse_routed_scaling_factor_in_topk(self):
-        return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
-            isinstance(self.quant_method, Fp8MoEMethod)
-            and self.quant_method.use_cutlass_fused_experts_fp8
-        )
-

 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
......
@@ -166,16 +166,15 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
 else:
     pass

-if _is_hip:
-    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-        decode_attention_fwd_grouped_rope,
-    )

 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -229,7 +228,7 @@ def _dispatch_mla_subtype(attn, forward_batch):
     return AttnForwardMethod.MLA


-class BackendRegistry:
+class AttentionBackendRegistry:
     _handlers = {}

     @classmethod
@@ -241,7 +240,7 @@ class BackendRegistry:
         return cls._handlers.get(backend_name, cls._handlers.get("triton"))


-def handle_ascend(attn, forward_batch):
+def handle_attention_ascend(attn, forward_batch):
     if (
         forward_batch.forward_mode.is_extend()
         and not forward_batch.forward_mode.is_target_verify()
@@ -268,7 +267,7 @@ def _is_extend_without_speculative(forward_batch):
     )


-def _handle_backend(attn, forward_batch, backend_name):
+def _handle_attention_backend(attn, forward_batch, backend_name):
     sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
     disable_ragged = (
         backend_name in ["flashinfer", "flashmla"]
@@ -290,28 +289,28 @@ def _handle_backend(attn, forward_batch, backend_name):
     return _dispatch_mla_subtype(attn, forward_batch)


-def handle_flashinfer(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "flashinfer")
+def handle_attention_flashinfer(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashinfer")


-def handle_fa3(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "fa3")
+def handle_attention_fa3(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "fa3")


-def handle_flashmla(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "flashmla")
+def handle_attention_flashmla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashmla")


-def handle_cutlass_mla(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "cutlass_mla")
+def handle_attention_cutlass_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")


-def handle_fa4(attn, forward_batch):
+def handle_attention_fa4(attn, forward_batch):
     # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
     return AttnForwardMethod.MHA_CHUNKED_KV


-def handle_trtllm_mla(attn, forward_batch):
+def handle_attention_trtllm_mla(attn, forward_batch):
     sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
     if _is_extend_without_speculative(forward_batch) and (
         not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
@@ -321,7 +320,7 @@ def handle_trtllm_mla(attn, forward_batch):
     return _dispatch_mla_subtype(attn, forward_batch)


-def handle_aiter(attn, forward_batch):
+def handle_attention_aiter(attn, forward_batch):
     if _is_extend_without_speculative(forward_batch):
         if is_dp_attention_enabled():
             if sum(forward_batch.extend_prefix_lens_cpu) == 0:
@@ -334,7 +333,7 @@ def handle_aiter(attn, forward_batch):
     return AttnForwardMethod.MLA


-def handle_triton(attn, forward_batch):
+def handle_attention_triton(attn, forward_batch):
     if (
         _is_extend_without_speculative(forward_batch)
         and sum(forward_batch.extend_prefix_lens_cpu) == 0
@@ -541,7 +540,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
-            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
             # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
             # and requires the output format to be standard. We use quant_config to determine the output format.
             output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
@@ -838,13 +837,13 @@ class DeepseekV2MoE(nn.Module):
         if shared_output is not None:
             x = shared_output
-            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
                 x.add_(final_hidden_states)
             else:
                 x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
                 final_hidden_states *= self.routed_scaling_factor

         return final_hidden_states
@@ -1217,7 +1216,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend

-        handler = BackendRegistry.get_handler(attention_backend)
+        handler = AttentionBackendRegistry.get_handler(attention_backend)
         return handler(self, forward_batch)

     def op_prepare(self, state):
@@ -3092,15 +3091,15 @@ class DeepseekV2ForCausalLM(nn.Module):
         )


-BackendRegistry.register("ascend", handle_ascend)
-BackendRegistry.register("flashinfer", handle_flashinfer)
-BackendRegistry.register("fa3", handle_fa3)
-BackendRegistry.register("flashmla", handle_flashmla)
-BackendRegistry.register("cutlass_mla", handle_cutlass_mla)
-BackendRegistry.register("fa4", handle_fa4)
-BackendRegistry.register("trtllm_mla", handle_trtllm_mla)
-BackendRegistry.register("aiter", handle_aiter)
-BackendRegistry.register("triton", handle_triton)
+AttentionBackendRegistry.register("ascend", handle_attention_ascend)
+AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
+AttentionBackendRegistry.register("fa3", handle_attention_fa3)
+AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
+AttentionBackendRegistry.register("fa4", handle_attention_fa4)
+AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
+AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+AttentionBackendRegistry.register("triton", handle_attention_triton)


 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
......
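
For readers unfamiliar with the dispatch mechanism being renamed above, the following is a minimal sketch of the registry pattern that `AttentionBackendRegistry` follows: handlers are registered under backend-name keys and looked up at dispatch time, with "triton" as the fallback, mirroring the get() fallback visible in the diff. The handler bodies, return values, and the `AttentionBackendRegistrySketch` name are illustrative placeholders, not the real sglang implementation.

# Minimal sketch of the registry/dispatch pattern; names and return values are placeholders.


class AttentionBackendRegistrySketch:
    _handlers = {}

    @classmethod
    def register(cls, backend_name, handler):
        cls._handlers[backend_name] = handler

    @classmethod
    def get_handler(cls, backend_name):
        # Fall back to the "triton" handler when the backend name is unknown.
        return cls._handlers.get(backend_name, cls._handlers.get("triton"))


def handle_attention_triton(attn, forward_batch):
    return "MLA"


def handle_attention_flashinfer(attn, forward_batch):
    return "MHA_CHUNKED_KV"


AttentionBackendRegistrySketch.register("triton", handle_attention_triton)
AttentionBackendRegistrySketch.register("flashinfer", handle_attention_flashinfer)

# Dispatch: look up the handler by backend name and call it.
handler = AttentionBackendRegistrySketch.get_handler("flashinfer")
print(handler(None, None))  # MHA_CHUNKED_KV
print(AttentionBackendRegistrySketch.get_handler("unknown")(None, None))  # MLA (triton fallback)
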