Unverified Commit b65db028 authored by fzyzcjy, committed by GitHub

Tiny cleanup deepseek_v2.py (#11163)

parent 948278f1
@@ -234,6 +234,13 @@ class FusedMoE(torch.nn.Module):
         self.quant_method.create_moe_runner(self, self.moe_runner_config)
         self.dispatcher = StandardDispatcher()
+
+        self.should_fuse_routed_scaling_factor_in_topk = isinstance(
+            self.quant_method, ModelOptNvFp4FusedMoEMethod
+        ) or (
+            isinstance(self.quant_method, Fp8MoEMethod)
+            and self.quant_method.use_cutlass_fused_experts_fp8
+        )
 
     def _load_per_tensor_weight_scale(
         self,
         shard_id: str,
@@ -936,12 +943,6 @@ class FusedMoE(torch.nn.Module):
             for shard_id in ["w1", "w2", "w3"]
         ]
-
-    def should_fuse_routed_scaling_factor_in_topk(self):
-        return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
-            isinstance(self.quant_method, Fp8MoEMethod)
-            and self.quant_method.use_cutlass_fused_experts_fp8
-        )
 
 
 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
...
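A note on the two hunks above: the helper that decided whether the routed scaling factor is fused into top-k is replaced by a boolean computed once in __init__, so call sites read an attribute instead of invoking a method on every forward pass. A minimal sketch of that pattern, with made-up class names rather than the sglang ones:

class _CutlassFp8MethodSketch:
    """Hypothetical stand-in for a quantization method object."""
    use_cutlass_fused_experts_fp8 = True


class _MoELayerSketch:
    def __init__(self, quant_method):
        self.quant_method = quant_method
        # Decided once at construction time; every later forward pass
        # only reads this attribute.
        self.should_fuse_routed_scaling_factor_in_topk = (
            isinstance(quant_method, _CutlassFp8MethodSketch)
            and quant_method.use_cutlass_fused_experts_fp8
        )


layer = _MoELayerSketch(_CutlassFp8MethodSketch())
assert layer.should_fuse_routed_scaling_factor_in_topk  # attribute access, not a call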
@@ -166,16 +166,15 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
-    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-        decode_attention_fwd_grouped_rope,
-    )
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
 else:
     pass
 
+if _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
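The hunk above only moves the ROCm-only import out of the if/elif ladder into its own if _is_hip: block; nothing about the import itself changes. For readers unfamiliar with the idiom, a tiny sketch of a platform-guarded import, using a made-up module name purely for illustration:

import importlib.util

# Import an optional, platform-specific module only when it is installed.
# "hypothetical_rocm_kernels" is not a real package; it stands in for
# modules such as the ROCm decode-rope kernel imported above.
if importlib.util.find_spec("hypothetical_rocm_kernels") is not None:
    from hypothetical_rocm_kernels import decode_rope
else:
    decode_rope = None  # callers must check for None before using it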
@@ -229,7 +228,7 @@ def _dispatch_mla_subtype(attn, forward_batch):
     return AttnForwardMethod.MLA
 
 
-class BackendRegistry:
+class AttentionBackendRegistry:
     _handlers = {}
 
     @classmethod
@@ -241,7 +240,7 @@ class BackendRegistry:
         return cls._handlers.get(backend_name, cls._handlers.get("triton"))
 
 
-def handle_ascend(attn, forward_batch):
+def handle_attention_ascend(attn, forward_batch):
     if (
         forward_batch.forward_mode.is_extend()
         and not forward_batch.forward_mode.is_target_verify()
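For context on the rename: AttentionBackendRegistry is a plain string-keyed handler registry. A self-contained sketch of the same pattern (simplified names, not the sglang implementation), including the fallback to the "triton" handler visible in the get_handler context line above:

class AttentionBackendRegistrySketch:
    """Minimal string-keyed handler registry; not the sglang implementation."""

    _handlers = {}

    @classmethod
    def register(cls, backend_name, handler):
        cls._handlers[backend_name] = handler

    @classmethod
    def get_handler(cls, backend_name):
        # Unknown backend names fall back to the "triton" handler,
        # mirroring the get_handler context line in the diff above.
        return cls._handlers.get(backend_name, cls._handlers.get("triton"))


AttentionBackendRegistrySketch.register("triton", lambda attn, batch: "MLA")
AttentionBackendRegistrySketch.register("fa3", lambda attn, batch: "MHA_CHUNKED_KV")

handler = AttentionBackendRegistrySketch.get_handler("unknown-backend")
print(handler(None, None))  # falls back to the triton handler -> "MLA"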
@@ -268,7 +267,7 @@ def _is_extend_without_speculative(forward_batch):
     )
 
 
-def _handle_backend(attn, forward_batch, backend_name):
+def _handle_attention_backend(attn, forward_batch, backend_name):
     sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
     disable_ragged = (
         backend_name in ["flashinfer", "flashmla"]
@@ -290,28 +289,28 @@ def _handle_backend(attn, forward_batch, backend_name):
     return _dispatch_mla_subtype(attn, forward_batch)
 
 
-def handle_flashinfer(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "flashinfer")
+def handle_attention_flashinfer(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashinfer")
 
 
-def handle_fa3(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "fa3")
+def handle_attention_fa3(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "fa3")
 
 
-def handle_flashmla(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "flashmla")
+def handle_attention_flashmla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashmla")
 
 
-def handle_cutlass_mla(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "cutlass_mla")
+def handle_attention_cutlass_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
 
 
-def handle_fa4(attn, forward_batch):
+def handle_attention_fa4(attn, forward_batch):
     # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
     return AttnForwardMethod.MHA_CHUNKED_KV
 
 
-def handle_trtllm_mla(attn, forward_batch):
+def handle_attention_trtllm_mla(attn, forward_batch):
     sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
     if _is_extend_without_speculative(forward_batch) and (
         not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
@@ -321,7 +320,7 @@ def handle_trtllm_mla(attn, forward_batch):
     return _dispatch_mla_subtype(attn, forward_batch)
 
 
-def handle_aiter(attn, forward_batch):
+def handle_attention_aiter(attn, forward_batch):
     if _is_extend_without_speculative(forward_batch):
         if is_dp_attention_enabled():
             if sum(forward_batch.extend_prefix_lens_cpu) == 0:
@@ -334,7 +333,7 @@ def handle_aiter(attn, forward_batch):
     return AttnForwardMethod.MLA
 
 
-def handle_triton(attn, forward_batch):
+def handle_attention_triton(attn, forward_batch):
     if (
         _is_extend_without_speculative(forward_batch)
         and sum(forward_batch.extend_prefix_lens_cpu) == 0
@@ -541,7 +540,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
-            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
             # and requires the output format to be standard. We use quant_config to determine the output format.
             output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
@@ -838,13 +837,13 @@ class DeepseekV2MoE(nn.Module):
         if shared_output is not None:
             x = shared_output
-            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
                 x.add_(final_hidden_states)
             else:
                 x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
                 final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
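The hunk above only swaps the method call for attribute access; the surrounding branches decide where the routed scaling factor is applied. A small PyTorch sketch with made-up tensors showing why adding a pre-scaled expert output and adding with alpha= give the same result:

import torch

routed_scaling_factor = 2.5
shared_output = torch.ones(4)
expert_output = torch.full((4,), 3.0)

# Case 1: the scaling factor was already fused into the top-k weights,
# so the expert output arrives pre-scaled and is added as-is.
prescaled = expert_output * routed_scaling_factor
x1 = shared_output.clone()
x1.add_(prescaled)

# Case 2: scaling was not fused, so it is applied during the addition
# via the alpha argument of add_().
x2 = shared_output.clone()
x2.add_(expert_output, alpha=routed_scaling_factor)

assert torch.allclose(x1, x2)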
@@ -1217,7 +1216,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
 
-        handler = BackendRegistry.get_handler(attention_backend)
+        handler = AttentionBackendRegistry.get_handler(attention_backend)
         return handler(self, forward_batch)
 
     def op_prepare(self, state):
@@ -3092,15 +3091,15 @@ class DeepseekV2ForCausalLM(nn.Module):
     )
 
 
-BackendRegistry.register("ascend", handle_ascend)
-BackendRegistry.register("flashinfer", handle_flashinfer)
-BackendRegistry.register("fa3", handle_fa3)
-BackendRegistry.register("flashmla", handle_flashmla)
-BackendRegistry.register("cutlass_mla", handle_cutlass_mla)
-BackendRegistry.register("fa4", handle_fa4)
-BackendRegistry.register("trtllm_mla", handle_trtllm_mla)
-BackendRegistry.register("aiter", handle_aiter)
-BackendRegistry.register("triton", handle_triton)
+AttentionBackendRegistry.register("ascend", handle_attention_ascend)
+AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
+AttentionBackendRegistry.register("fa3", handle_attention_fa3)
+AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
+AttentionBackendRegistry.register("fa4", handle_attention_fa4)
+AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
+AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+AttentionBackendRegistry.register("triton", handle_attention_triton)
 
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
...
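Taken together, the renamed handlers are registered under backend-name strings at import time and resolved with a single lookup at the call site, as in the DeepseekV2AttentionMLA hunk above. A short usage sketch with hypothetical handler and enum names (a plain dict stands in for the registry class):

from enum import Enum, auto


class _AttnForwardMethodSketch(Enum):
    """Hypothetical stand-in for the AttnForwardMethod enum referenced above."""
    MHA_CHUNKED_KV = auto()
    MLA = auto()


# Handlers share one signature: (attn, forward_batch) -> forward-method enum.
def _handle_attention_fa4_sketch(attn, forward_batch):
    return _AttnForwardMethodSketch.MHA_CHUNKED_KV


def _handle_attention_triton_sketch(attn, forward_batch):
    return _AttnForwardMethodSketch.MLA


# Module-level registration, as in the last hunk of the diff ...
_handlers = {
    "fa4": _handle_attention_fa4_sketch,
    "triton": _handle_attention_triton_sketch,
}

# ... and one-line resolution plus call at the attention layer.
handler = _handlers.get("fa4", _handlers["triton"])
print(handler(None, None))  # _AttnForwardMethodSketch.MHA_CHUNKED_KV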