"vscode:/vscode.git/clone" did not exist on "c47576ca6e699c6f8eaa8dfc4959e2e85dec0c72"
Unverified Commit 70c87138 authored by NielsRogge, committed by GitHub

🚨 [Mistral and friends] Update MLP (#31057)

Update MLP
parent d475f767
@@ -1001,7 +1001,6 @@ class JambaMambaMixer(nn.Module):
 class JambaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -1009,8 +1008,8 @@ class JambaMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba
...
@@ -160,7 +160,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class MistralMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -168,8 +167,8 @@ class MistralMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
...
@@ -173,7 +173,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 class Qwen2MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -181,8 +180,8 @@ class Qwen2MLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
...
@@ -197,7 +197,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 class StableLmMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -205,8 +204,8 @@ class StableLmMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 class StableLmLayerNormPerHead(nn.Module):
...
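
For reference, all four touched classes now share the same gated-MLP shape: the unused self.config attribute is dropped and the forward argument is renamed from x to hidden_state. Below is a minimal, runnable sketch of that pattern, assuming torch and transformers are installed; DemoConfig and DemoMLP are hypothetical stand-ins for illustration, not code from this diff.

from dataclasses import dataclass

import torch
import torch.nn as nn
from transformers.activations import ACT2FN


@dataclass
class DemoConfig:
    # Hypothetical stand-in for MistralConfig, Qwen2Config, etc.
    hidden_size: int = 16
    intermediate_size: int = 32
    hidden_act: str = "silu"


class DemoMLP(nn.Module):
    # Same structure as JambaMLP / MistralMLP / Qwen2MLP / StableLmMLP after this change.
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        # Gated activation: act(gate_proj(h)) * up_proj(h), projected back to hidden_size.
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


mlp = DemoMLP(DemoConfig())
out = mlp(torch.randn(2, 4, 16))  # (batch, seq_len, hidden_size)
print(out.shape)  # torch.Size([2, 4, 16])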