Unverified commit 54d0b1c2 authored by Younes Belkada, committed by GitHub

[`Mixtral`] Change mistral op order (#27955)

up
parent 4850aaba
@@ -663,10 +663,10 @@ class MixtralBLockSparseTop2MLP(nn.Module):
         self.act_fn = ACT2FN[config.hidden_act]

-    def forward(self, hidden_states, routing_weights):
+    def forward(self, hidden_states):
         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
         current_hidden_states = self.w2(current_hidden_states)
-        return routing_weights * current_hidden_states
+        return current_hidden_states

 MISTRAL_ATTENTION_CLASSES = {
@@ -736,7 +736,7 @@ class MixtralSparseMoeBlock(nn.Module):
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
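For context (not part of the commit itself): the change moves the multiplication by `routing_weights` out of the expert MLP's `forward` and into the caller in `MixtralSparseMoeBlock`, so each expert computes only its SwiGLU projection. The sketch below illustrates the resulting call pattern; `Top2MLP`, the dimensions, and the random tensors are simplified stand-ins and not the actual Hugging Face classes.

```python
import torch
import torch.nn as nn


class Top2MLP(nn.Module):
    """Simplified stand-in for MixtralBLockSparseTop2MLP after this change:
    the expert no longer receives routing weights at all."""

    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        # SwiGLU-style projection; no routing weights involved here anymore.
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        return self.w2(current_hidden_states)


# The caller now applies the routing weights itself, mirroring
# `expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]`.
hidden_dim, ffn_dim = 16, 32
expert = Top2MLP(hidden_dim, ffn_dim)
current_state = torch.randn(4, hidden_dim)   # tokens routed to this expert
routing_weights = torch.rand(4, 1)           # their top-k gate values
current_hidden_states = expert(current_state) * routing_weights
print(current_hidden_states.shape)           # torch.Size([4, 16])
```

Since the routing weights enter as an elementwise scale, applying them before or after the expert's `w2` projection gives the same result up to floating-point rounding; the commit only moves where that multiplication happens.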