"git@developer.sourcefind.cn:OpenDAS/nerfacc.git" did not exist on "29c59cab0228b91117b360cf5a5d9bf80ee38572"
Unverified Commit efedbe6c authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix global input scale incompatible with CuTe DSL moe (#10370)

parent 3a77c80b
......@@ -1187,6 +1187,21 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
elif self.enable_flashinfer_cutedsl_moe:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
w2_input_scale = layer.w2_input_scale
def _slice_scale(w):
assert w.shape == (layer.num_experts,)
assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
return w[
layer.moe_ep_rank
* layer.num_local_experts : (layer.moe_ep_rank + 1)
* layer.num_local_experts
]
w13_input_scale = _slice_scale(w13_input_scale)
w2_input_scale = _slice_scale(w2_input_scale)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
w2_input_scale = layer.w2_input_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment