Unverified Commit 83087247 authored by Cheng Wan, committed by GitHub

[hotfix] missing `w13_weight_fp8` and `w2_weight_fp8` in UE8M0 requantization (#12259)

parent 334543ff
@@ -131,23 +131,6 @@ class DeepEPMoE(FusedMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        elif not _is_npu:
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                (
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w13_weight_scale
-                ),
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                (
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant or self.use_w4afp8
-                    else self.w2_weight_scale
-                ),
-            )
 
     def forward(
         self,
......
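The deleted `elif not _is_npu` branch cached `(weight, scale)` pairs on the module at construction time, choosing the inverse block scale for block-quant and w4afp8 checkpoints and the per-tensor scale otherwise. A minimal sketch of that selection as a standalone helper (hypothetical, not part of the patch; attribute and flag names taken from the diff), useful when reading the later hunks where the pairing is rebuilt at the call sites:

```python
def fp8_weight_scale_pair(moe, prefix: str):
    """Pair `<prefix>_weight` with the scale tensor matching the quant mode.

    Hypothetical sketch mirroring the deleted branch: block-quant and
    w4afp8 checkpoints carry `<prefix>_weight_scale_inv`; other FP8
    checkpoints carry the plain `<prefix>_weight_scale`.
    """
    weight = getattr(moe, f"{prefix}_weight")
    if moe.use_block_quant or moe.use_w4afp8:
        scale = getattr(moe, f"{prefix}_weight_scale_inv")
    else:
        scale = getattr(moe, f"{prefix}_weight_scale")
    return weight, scale
```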
@@ -227,15 +227,16 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         hidden_states_device = running_state["hidden_states_device"]
-        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            b, s_mn, s_k = hidden_states_scale.shape
-            assert (
-                s_mn % 4 == 0 and s_k % 4 == 0
-            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
 
         # GroupGemm-0
         if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+            if hidden_states_scale.dtype != torch.int:
+                b, s_mn, s_k = hidden_states_scale.shape
+                assert (
+                    s_mn % 4 == 0 and s_k % 4 == 0
+                ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+                hidden_states_scale = _cast_to_e8m0_with_rounding_up(
+                    hidden_states_scale
+                )
         else:
             hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                 hidden_states_scale
......
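This hunk makes the UE8M0 cast idempotent: scales that already arrive as packed integer exponents (`dtype == torch.int`) now skip both the alignment assert and the cast. A simplified sketch of the round-up-to-power-of-two idea behind `_cast_to_e8m0_with_rounding_up` follows; these are assumed semantics only, since the real helper (whose body is not shown here) also packs the exponents for DeepGEMM:

```python
import torch

def cast_to_e8m0_rounding_up_sketch(scale: torch.Tensor) -> torch.Tensor:
    """Round each positive fp32 scale up to the nearest power of two and
    keep only the exponent as an integer (E8M0 has no mantissa bits).

    Assumed semantics of the internal `_cast_to_e8m0_with_rounding_up`.
    """
    # ceil(log2(s)) is the smallest e with 2**e >= s, so the quantized
    # scale never underestimates the original one.
    return torch.ceil(torch.log2(scale.float())).to(torch.int)

scales = torch.tensor([0.3, 1.0, 3.7])
print(cast_to_e8m0_rounding_up_sketch(scales))  # tensor([-1, 0, 2], dtype=torch.int32)
```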
@@ -3289,8 +3289,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                 experts = layer.mlp.experts
                 if isinstance(experts, DeepEPMoE):
                     for w in [
-                        experts.w13_weight_fp8,
-                        experts.w2_weight_fp8,
+                        (experts.w13_weight, experts.w13_weight_scale_inv),
+                        (experts.w2_weight, experts.w2_weight_scale_inv),
                     ]:
                         requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
         else:
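With the cached `w13_weight_fp8` / `w2_weight_fp8` attributes gone, this load-time path inlines the `(weight, scale)` tuples; the LongcatFlash hunk at the end of the diff applies the identical fix. `requant_weight_ue8m0_inplace` rewrites each pair so the scales become exactly representable as powers of two while the dequantized product is preserved. A toy model of that invariant, assuming fp32 tensors and a degenerate one-scale-per-element layout (the real helper works on FP8 weights with `weight_block_size` blocks):

```python
import torch

def requant_weight_ue8m0_inplace_sketch(w: torch.Tensor, s: torch.Tensor) -> None:
    """Toy model with a 1x1 block size and fp32 storage; only the
    preserved invariant w * s is intended to match the real helper."""
    s_pow2 = torch.exp2(torch.ceil(torch.log2(s)))  # round scales up to powers of two
    w.mul_(s / s_pow2)  # rescale weights so the dequantized product is unchanged
    s.copy_(s_pow2)     # scales are now exactly representable in E8M0

w = torch.randn(4, 4)
s = torch.rand(4, 4) + 0.5
dequant_before = w * s
requant_weight_ue8m0_inplace_sketch(w, s)
assert torch.allclose(w * s, dequant_before)
```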
@@ -3338,10 +3338,26 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             experts = layer.mlp.experts
+            w13_weight_fp8 = (
+                experts.w13_weight,
+                (
+                    experts.w13_weight_scale_inv
+                    if hasattr(experts, "w13_weight_scale_inv")
+                    else experts.w13_weight_scale
+                ),
+            )
+            w2_weight_fp8 = (
+                experts.w2_weight,
+                (
+                    experts.w2_weight_scale_inv
+                    if hasattr(experts, "w2_weight_scale_inv")
+                    else experts.w2_weight_scale
+                ),
+            )
             if isinstance(experts, DeepEPMoE):
                 for w in [
-                    experts.w13_weight_fp8,
-                    experts.w2_weight_fp8,
+                    w13_weight_fp8,
+                    w2_weight_fp8,
                 ]:
                     transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
......
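Unlike the `__init__` branch deleted in the first hunk, which consulted the quant-mode flags, this path probes attribute existence directly, so it also covers checkpoints that only define `*_weight_scale`. The sixteen added lines could be condensed into a single probe; the helper below is a hypothetical refactor expressing the same pattern, not what the patch does:

```python
def probe_weight_scale_pair(experts, prefix: str):
    """Return (`<prefix>_weight`, scale), preferring the inverse block scale.

    Hypothetical condensation of the added tuple-building code: use
    `<prefix>_weight_scale_inv` when present, else `<prefix>_weight_scale`.
    """
    name = f"{prefix}_weight_scale_inv"
    if not hasattr(experts, name):
        name = f"{prefix}_weight_scale"
    return getattr(experts, f"{prefix}_weight"), getattr(experts, name)

# w13_weight_fp8 = probe_weight_scale_pair(experts, "w13")
# w2_weight_fp8 = probe_weight_scale_pair(experts, "w2")
```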
@@ -821,8 +821,8 @@ class LongcatFlashForCausalLM(nn.Module):
                 experts = layer.mlp.experts
                 if isinstance(experts, DeepEPMoE):
                     for w in [
-                        experts.w13_weight_fp8,
-                        experts.w2_weight_fp8,
+                        (experts.w13_weight, experts.w13_weight_scale_inv),
+                        (experts.w2_weight, experts.w2_weight_scale_inv),
                     ]:
                         requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
......