Unverified Commit 679ca5d8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix MoE for the Transformers modelling backend (#34436)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent f2c47886
...@@ -45,7 +45,6 @@ class TransformersFusedMoE(FusedMoE): ...@@ -45,7 +45,6 @@ class TransformersFusedMoE(FusedMoE):
# --8<-- [end:transformers_fused_moe] # --8<-- [end:transformers_fused_moe]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._topk_ids: torch.Tensor = None self._topk_ids: torch.Tensor = None
def custom_routing_function(hidden_states, gating_output, topk, renormalize): def custom_routing_function(hidden_states, gating_output, topk, renormalize):
...@@ -63,7 +62,8 @@ class TransformersFusedMoE(FusedMoE): ...@@ -63,7 +62,8 @@ class TransformersFusedMoE(FusedMoE):
(topk_ids,) = dist_group.all_gatherv([topk_ids], 0, sizes) (topk_ids,) = dist_group.all_gatherv([topk_ids], 0, sizes)
return topk_weights, topk_ids return topk_weights, topk_ids
self.custom_routing_function = custom_routing_function kwargs["custom_routing_function"] = custom_routing_function
super().__init__(*args, **kwargs)
def forward( def forward(
self, self,
...@@ -94,7 +94,7 @@ def transformers_moe_forward( ...@@ -94,7 +94,7 @@ def transformers_moe_forward(
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self._topk_ids = topk_ids self._topk_ids = topk_ids
# Clone hidden_states because it will be mutated in-place in FusedMoE # Clone hidden_states because it will be mutated in-place in FusedMoE
return self.forward_impl(hidden_states.clone(), topk_weights) return self.runner.forward(hidden_states.clone(), topk_weights)
def transformers_moe_forward_fake( def transformers_moe_forward_fake(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment