Unverified Commit fa4b7055 authored by Hemanth Acharya's avatar Hemanth Acharya Committed by GitHub
Browse files

[ROCm] Cast score correction bias tensor during model construction for DeepSeek/Kimi-K2 (#39999)


Signed-off-by: default avatarHemanth Acharya <heachary@amd.com>
parent 447c372a
...@@ -1782,6 +1782,8 @@ class rocm_aiter_ops: ...@@ -1782,6 +1782,8 @@ class rocm_aiter_ops:
need_renorm: bool, need_renorm: bool,
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
) -> None: ) -> None:
if correction_bias.dtype != gating_output.dtype:
correction_bias = correction_bias.to(gating_output.dtype)
torch.ops.vllm.rocm_aiter_biased_grouped_topk( torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output, gating_output,
correction_bias, correction_bias,
......
...@@ -152,7 +152,7 @@ def rocm_aiter_grouped_topk( ...@@ -152,7 +152,7 @@ def rocm_aiter_grouped_topk(
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
rocm_aiter_ops.biased_grouped_topk( rocm_aiter_ops.biased_grouped_topk(
gating_output, gating_output,
e_score_correction_bias.to(gating_output.dtype), e_score_correction_bias,
topk_weights, topk_weights,
topk_ids, topk_ids,
num_expert_group, num_expert_group,
......
...@@ -136,7 +136,7 @@ def fused_topk_bias( ...@@ -136,7 +136,7 @@ def fused_topk_bias(
) )
rocm_aiter_ops.biased_grouped_topk( rocm_aiter_ops.biased_grouped_topk(
gating_output, gating_output,
e_score_correction_bias.to(gating_output.dtype), e_score_correction_bias,
topk_weights, topk_weights,
topk_ids, topk_ids,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
......
...@@ -349,6 +349,21 @@ class DeepseekV2MoE(nn.Module): ...@@ -349,6 +349,21 @@ class DeepseekV2MoE(nn.Module):
else torch.bfloat16 else torch.bfloat16
) )
# Pre-cast the bias to match the gate output dtype so the
# conversion is not repeated on every forward pass. All
# downstream references (FusedMoE, router) share the same
# nn.Parameter object, so mutating .data propagates everywhere.
# Weight loading uses copy_(), which handles the dtype conversion.
# Only needed on ROCm where the aiter biased_grouped_topk kernel
# requires the bias dtype to match the gating output dtype.
if (
self.is_rocm_aiter_moe_enabled
and self.gate.e_score_correction_bias is not None
):
self.gate.e_score_correction_bias.data = (
self.gate.e_score_correction_bias.data.to(self.gate.out_dtype)
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment