chenpangpang / transformers

Unverified commit 70c87138, authored Jun 03, 2024 by NielsRogge, committed by GitHub on Jun 03, 2024
Parent: d475f767

🚨 [Mistral and friends] Update MLP (#31057)

Update MLP
Changes: 4 files, 8 additions and 12 deletions (+8 -12)

src/transformers/models/jamba/modeling_jamba.py         +2 -3
src/transformers/models/mistral/modeling_mistral.py     +2 -3
src/transformers/models/qwen2/modeling_qwen2.py         +2 -3
src/transformers/models/stablelm/modeling_stablelm.py   +2 -3
src/transformers/models/jamba/modeling_jamba.py
@@ -1001,7 +1001,6 @@ class JambaMambaMixer(nn.Module):
 class JambaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -1009,8 +1008,8 @@ class JambaMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]

-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

 # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba
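All four files apply the same two-part change to the same SwiGLU-style gated MLP, so the block is worth seeing on its own. Below is a minimal, self-contained sketch of the post-commit behavior; the class name and sizes are made up for illustration, and nn.SiLU stands in for the ACT2FN[config.hidden_act] lookup.

import torch
from torch import nn


class GatedMLP(nn.Module):
    # Hypothetical stand-in for JambaMLP / MistralMLP / Qwen2MLP / StableLmMLP.
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()  # stands in for ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        # down_proj(act(gate_proj(h)) * up_proj(h)) -- the gated ("SwiGLU") MLP.
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


# Quick shape check with made-up sizes: the output keeps the hidden size.
mlp = GatedMLP(hidden_size=16, intermediate_size=32)
print(mlp(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])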
src/transformers/models/mistral/modeling_mistral.py
@@ -160,7 +160,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class MistralMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -168,8 +167,8 @@ class MistralMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]

-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

 # Copied from transformers.models.llama.modeling_llama.repeat_kv
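The act_fn assignment above resolves the activation by name at construction time. Assuming the standard transformers.activations helper (not part of this diff), the lookup works roughly like this:

from transformers.activations import ACT2FN

# Indexing ACT2FN with the config's hidden_act string (e.g. "silu" for
# Mistral-style configs) returns an instantiated activation module.
act_fn = ACT2FN["silu"]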
src/transformers/models/qwen2/modeling_qwen2.py
@@ -173,7 +173,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 class Qwen2MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -181,8 +180,8 @@ class Qwen2MLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]

-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

 # Copied from transformers.models.llama.modeling_llama.repeat_kv
src/transformers/models/stablelm/modeling_stablelm.py
@@ -197,7 +197,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 class StableLmMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -205,8 +204,8 @@ class StableLmMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]

-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

 class StableLmLayerNormPerHead(nn.Module):
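The 🚨 prefix in the commit title presumably marks the x -> hidden_state rename as a potentially breaking change for callers that passed the input by keyword; a small hedged illustration with a hypothetical module:

import torch
from torch import nn


class TinyMLP(nn.Module):
    # Hypothetical module with the post-commit parameter name.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8, bias=False)

    def forward(self, hidden_state):
        return self.proj(hidden_state)


mlp = TinyMLP()
h = torch.randn(2, 8)

out = mlp(h)               # positional calls are unaffected by the rename
out = mlp(hidden_state=h)  # keyword calls must use the new parameter name
# out = mlp(x=h)           # the pre-commit keyword spelling would now raise a TypeError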