Unverified commit 83087247, authored by Cheng Wan, committed by GitHub

[hotfix] missing `w13_weight_fp8` and `w2_weight_fp8` in UE8M0 requantization (#12259)

parent 334543ff
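
Judging by the first hunk below, the cached `w13_weight_fp8` / `w2_weight_fp8` tuples were only created in an `elif not _is_npu:` branch, so when the preceding branch was taken a `DeepEPMoE` instance could reach UE8M0 requantization with those attributes never set. The hotfix removes the conditionally-created cache entirely and rebuilds the `(weight, scale)` pairs at each call site in `DeepseekV2ForCausalLM` and `LongcatFlashForCausalLM`. A minimal, runnable sketch of the failure pattern (illustrative names, not the real classes):

```python
# Sketch of the failure mode this hotfix addresses: an attribute is only
# conditionally cached on the module, but a model-level pass always reads it.
class Experts:
    def __init__(self):
        self.w13_weight = "fp8 weight"           # stand-in for a tensor
        self.w13_weight_scale_inv = "fp8 scale"  # stand-in for a tensor
        # Previously cached only on some paths:
        # self.w13_weight_fp8 = (weight, scale)

experts = Experts()
try:
    weight, scale = experts.w13_weight_fp8  # the old call-site pattern
except AttributeError:
    # The fix: build the pair at the call site instead.
    weight, scale = experts.w13_weight, experts.w13_weight_scale_inv
```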
@@ -131,23 +131,6 @@ class DeepEPMoE(FusedMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        elif not _is_npu:
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                (
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w13_weight_scale
-                ),
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                (
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w2_weight_scale
-                ),
-            )
 
     def forward(
         self,
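
The removed cache also encoded which scale attribute pairs with each weight: block-quantized and W4A-FP8 checkpoints store inverse scales under `*_weight_scale_inv`, other FP8 schemes under `*_weight_scale`. Call sites now need that selection rule themselves. A hypothetical helper (not part of the patch) capturing the rule as the constructor expressed it, with the quantization flags available:

```python
def pick_scale(mod, prefix: str, use_block_quant: bool, use_w4afp8: bool):
    # Hypothetical helper mirroring the removed block: inverse scales
    # for block-quant / w4afp8 checkpoints, plain scales otherwise.
    name = (
        f"{prefix}_weight_scale_inv"
        if use_block_quant or use_w4afp8
        else f"{prefix}_weight_scale"
    )
    return getattr(mod, name)
```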
@@ -227,15 +227,16 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         hidden_states_device = running_state["hidden_states_device"]
 
-        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            b, s_mn, s_k = hidden_states_scale.shape
-            assert (
-                s_mn % 4 == 0 and s_k % 4 == 0
-            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
-
         # GroupGemm-0
         if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+            if hidden_states_scale.dtype != torch.int:
+                b, s_mn, s_k = hidden_states_scale.shape
+                assert (
+                    s_mn % 4 == 0 and s_k % 4 == 0
+                ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+                hidden_states_scale = _cast_to_e8m0_with_rounding_up(
+                    hidden_states_scale
+                )
         else:
             hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                 hidden_states_scale
             )
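
The rewritten GroupGemm-0 branch folds the alignment assert under a new idempotency guard: scales already packed as `torch.int` (i.e. already E8M0) are left alone, and only float scales are checked and cast. A minimal sketch of what "cast to E8M0 with rounding up" means, assuming E8M0 denotes exponent-only, power-of-two scales and that the real helper packs biased exponents into an int tensor (hence the dtype check):

```python
import torch

def cast_to_e8m0_rounding_up_sketch(scale: torch.Tensor) -> torch.Tensor:
    # E8M0 keeps only a power-of-two exponent. Rounding *up* picks the
    # smallest power of two >= scale, so values quantized with the new
    # scale cannot overflow. Assumes scale > 0; the real
    # _cast_to_e8m0_with_rounding_up presumably also packs the biased
    # exponents into an int tensor, which the new guard detects.
    exponent = torch.ceil(torch.log2(scale.float()))
    return torch.exp2(exponent)
```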
@@ -3289,8 +3289,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                 experts = layer.mlp.experts
                 if isinstance(experts, DeepEPMoE):
                     for w in [
-                        experts.w13_weight_fp8,
-                        experts.w2_weight_fp8,
+                        (experts.w13_weight, experts.w13_weight_scale_inv),
+                        (experts.w2_weight, experts.w2_weight_scale_inv),
                     ]:
                         requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
                 else:
@@ -3338,10 +3338,26 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             experts = layer.mlp.experts
+            w13_weight_fp8 = (
+                experts.w13_weight,
+                (
+                    experts.w13_weight_scale_inv
+                    if hasattr(experts, "w13_weight_scale_inv")
+                    else experts.w13_weight_scale
+                ),
+            )
+            w2_weight_fp8 = (
+                experts.w2_weight,
+                (
+                    experts.w2_weight_scale_inv
+                    if hasattr(experts, "w2_weight_scale_inv")
+                    else experts.w2_weight_scale
+                ),
+            )
             if isinstance(experts, DeepEPMoE):
                 for w in [
-                    experts.w13_weight_fp8,
-                    experts.w2_weight_fp8,
+                    w13_weight_fp8,
+                    w2_weight_fp8,
                 ]:
                     transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
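
Unlike the removed constructor code, this model-level pass cannot consult `use_block_quant` / `use_w4afp8`, so it probes the attribute names with `hasattr` instead. The same fallback is inlined twice; a hypothetical helper (not in the patch) stating the rule once:

```python
def fp8_pair(experts, prefix: str):
    # Hypothetical helper: prefer the block-quant `*_weight_scale_inv`
    # name, fall back to `*_weight_scale` for per-tensor FP8 scales.
    scale_name = (
        f"{prefix}_weight_scale_inv"
        if hasattr(experts, f"{prefix}_weight_scale_inv")
        else f"{prefix}_weight_scale"
    )
    return getattr(experts, f"{prefix}_weight"), getattr(experts, scale_name)

# Usage equivalent to the inlined code above:
# w13_weight_fp8 = fp8_pair(experts, "w13")
# w2_weight_fp8 = fp8_pair(experts, "w2")
```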
@@ -821,8 +821,8 @@ class LongcatFlashForCausalLM(nn.Module):
             experts = layer.mlp.experts
             if isinstance(experts, DeepEPMoE):
                 for w in [
-                    experts.w13_weight_fp8,
-                    experts.w2_weight_fp8,
+                    (experts.w13_weight, experts.w13_weight_scale_inv),
+                    (experts.w2_weight, experts.w2_weight_scale_inv),
                 ]:
                     requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
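
This hunk mirrors the `DeepseekV2ForCausalLM` change above: both now hand explicit `(weight, scale_inv)` pairs to `requant_weight_ue8m0_inplace` rather than reading the removed attributes. Conceptually (a hedged sketch, not the real in-place, blockwise implementation), UE8M0 requantization rounds each block scale up to a power of two and folds the ratio back into the weights:

```python
import torch

def requant_ue8m0_scale_sketch(scale_inv: torch.Tensor) -> torch.Tensor:
    # Hedged sketch: round each block scale up to the nearest power of
    # two (UE8M0). The real requant_weight_ue8m0_inplace additionally
    # rescales the FP8 weight blocks by scale_inv / new_scale (<= 1,
    # with the block shape given by weight_block_size) so that the
    # dequantized values are preserved.
    return torch.exp2(torch.ceil(torch.log2(scale_inv.float())))
```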