Unverified Commit 468a8d72 authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

[Bugfix] Fix FusedMoEModularKernel for triton backend (#28913)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
parent 4c23690f
...@@ -755,8 +755,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -755,8 +755,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.w13_weight = w13_weight self.w13_weight = w13_weight
self.w2_weight = w2_weight self.w2_weight = w2_weight
layer.w13_weight = Parameter(w13_weight.storage.data, requires_grad=False) del layer.w13_weight
layer.w2_weight = Parameter(w2_weight.storage.data, requires_grad=False) del layer.w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
else: else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
...@@ -1065,8 +1067,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1065,8 +1067,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return triton_kernel_moe_forward( return triton_kernel_moe_forward(
hidden_states=x, hidden_states=x,
w1=self.w13_weight, w1=layer.w13_weight,
w2=self.w2_weight, w2=layer.w2_weight,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=top_k,
renormalize=renormalize, renormalize=renormalize,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment