Unverified Commit 012584ec authored by Jinyang Yuan, committed by GitHub

perf: Avoid unnecessary data type conversions for DeepSeek-V3 on Blackwell (#9834)


Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
parent 90dfe3de
@@ -655,6 +655,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
     # flashinfer uses this environment variable for various kernels from MoE to quant kernels
-    os.environ["TRTLLM_ENABLE_PDL"] = "1"
+    if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
+        os.environ["TRTLLM_ENABLE_PDL"] = "1"
     # Can also be passed as argument
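
The added guard keeps PDL enabled by default but respects an explicit opt-out: exporting TRTLLM_ENABLE_PDL=0 before launch now disables it. A minimal sketch of the same check in isolation (the helper name is hypothetical, not part of the patch):

import os

def resolve_trtllm_enable_pdl() -> str:  # hypothetical helper, for illustration only
    # Respect an explicit user opt-out ("0"); otherwise force-enable PDL,
    # which flashinfer reads for its MoE and quant kernels.
    if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
        os.environ["TRTLLM_ENABLE_PDL"] = "1"
    return os.environ["TRTLLM_ENABLE_PDL"]

# e.g. `TRTLLM_ENABLE_PDL=0 python -m sglang.launch_server ...` leaves the value at "0".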
@@ -67,7 +67,10 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FusedMoE,
+    _is_fp4_quantization_enabled,
+)
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -299,7 +302,9 @@ class MoEGate(nn.Module):
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(hidden_states, self.weight)
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
         elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
             logits = aiter_dsv3_router_gemm(
                 hidden_states, self.weight, gemm_output_zero_allocator
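
The router logits here are a small GEMM of the token hidden states against the gate weight, and out_dtype=torch.float32 asks the fused kernel for float32 logits directly rather than casting after the fact. A plain-PyTorch reference of that computation, as a sketch only (shapes are assumptions from the surrounding code; this is not the fused kernel):

import torch
import torch.nn.functional as F

def router_logits_reference(
    hidden_states: torch.Tensor,  # (num_tokens, hidden_size), typically bfloat16
    gate_weight: torch.Tensor,  # (n_routed_experts, hidden_size)
) -> torch.Tensor:
    # Reference path: matmul, then a separate cast to float32. The fused
    # dsv3_router_gemm kernel can emit float32 output directly via out_dtype,
    # skipping the extra conversion shown here.
    return F.linear(hidden_states, gate_weight).to(torch.float32)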
@@ -364,6 +369,9 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )
 
+        correction_bias = self.gate.e_score_correction_bias
+        if _is_fp4_quantization_enabled():
+            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -371,7 +379,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
+            correction_bias=correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
             force_topk=quant_config is None,
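
With FP4 quantization active, the correction bias is converted to bfloat16 once in the constructor and the pre-converted tensor is handed to TopK, so the routing path does not need to repeat the dtype conversion later (the kind of conversion the commit title refers to). A hedged sketch of the idea in isolation (the helper name is hypothetical; _is_fp4_quantization_enabled is the flag imported above):

import torch

def prepare_correction_bias(
    e_score_correction_bias: torch.Tensor, fp4_enabled: bool
) -> torch.Tensor:  # hypothetical helper, for illustration only
    # One-time conversion at construction: when the FP4 MoE path is active,
    # hand TopK a bfloat16 bias up front instead of casting per forward pass.
    if fp4_enabled:
        return e_score_correction_bias.to(torch.bfloat16)
    return e_score_correction_bias

# Roughly how the patch wires it up inside DeepseekV2MoE.__init__:
# correction_bias = prepare_correction_bias(
#     self.gate.e_score_correction_bias, _is_fp4_quantization_enabled()
# )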