Unverified Commit e72ae80b authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Bugfix] Support 2D input shape in MoE layer (#6287)

parent 8a924d22
...@@ -88,12 +88,13 @@ class MixtralMoE(nn.Module): ...@@ -88,12 +88,13 @@ class MixtralMoE(nn.Module):
tp_size=tp_size) tp_size=tp_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape # NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size) hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits) final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(num_tokens, hidden_size) return final_hidden_states.view(orig_shape)
class MixtralAttention(nn.Module): class MixtralAttention(nn.Module):
......
...@@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
bias=False) bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape # NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None shared_output = None
if self.shared_expert is not None: if self.shared_expert is not None:
...@@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): ...@@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states) final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(orig_shape)
class Qwen2MoeAttention(nn.Module): class Qwen2MoeAttention(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment