Unverified commit ded333fb, authored by Rohan Potdar, committed by GitHub
Browse files

[ROCm][Bugfix]: Only save unpadded sizes for shared_experts in MoERunner to fix rmsnorm pad fusion


[ROCm][Bugfix]: Only save unpadded sizes for shared_experts in MoERunner to fix rmsnorm pad fusion (#34636)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
parent 9d7577b2
...@@ -384,8 +384,11 @@ class DefaultMoERunner(MoERunner): ...@@ -384,8 +384,11 @@ class DefaultMoERunner(MoERunner):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# For latent MoE: save ORIGINAL hidden_states before transform # For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed) # (shared_experts need original dimension, routed experts use transformed)
if self.shared_experts is not None:
original_hidden_states = hidden_states original_hidden_states = hidden_states
original_hidden_dim = hidden_states.shape[-1] original_hidden_dim = hidden_states.shape[-1]
else:
original_hidden_states = None
# Apply transform for routed experts (e.g., latent projection for latent MoE) # Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states = self.apply_routed_input_transform(hidden_states) hidden_states = self.apply_routed_input_transform(hidden_states)
...@@ -407,7 +410,7 @@ class DefaultMoERunner(MoERunner): ...@@ -407,7 +410,7 @@ class DefaultMoERunner(MoERunner):
self._encode_layer_name(), self._encode_layer_name(),
) )
if isinstance(fused_output, tuple): if self.shared_experts is not None:
orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim] orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
else: else:
orig_hidden_dims = [transformed_hidden_dim] orig_hidden_dims = [transformed_hidden_dim]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment