transformers commit 425e1a04 (unverified)

add mlp bias for llama models (#30031)

* add bias
* fix quality

Authored May 03, 2024 by Mayank Mishra; committed via GitHub on May 03, 2024.
Parent: a0e77a1f
Showing 3 changed files with 8 additions and 5 deletions:

  src/transformers/models/cohere/modeling_cohere.py       +0 -1
  src/transformers/models/llama/configuration_llama.py    +5 -1
  src/transformers/models/llama/modeling_llama.py         +3 -3
src/transformers/models/cohere/modeling_cohere.py
@@ -161,7 +161,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)


-# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
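The only change in this file is the removal of the `# Copied from` marker. That marker lets the repository's copy-consistency check keep `CohereMLP` in sync with `LlamaMLP`; now that `LlamaMLP` reads `config.mlp_bias`, which the Cohere configuration presumably does not define, the link is dropped and `CohereMLP` keeps its bias-free projections. A rough sketch of the unchanged Cohere MLP under that assumption (not the file's exact contents):

```python
import torch.nn as nn

class CohereMLP(nn.Module):
    """Sketch of the Cohere MLP after this commit: still bias-free, because
    CohereConfig is assumed not to expose an mlp_bias field."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # bias stays hard-coded to False, unlike LlamaMLP after this commit
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
```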
src/transformers/models/llama/configuration_llama.py
@@ -94,10 +94,12 @@ class LlamaConfig(PretrainedConfig):
             these scaling strategies behave:
             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
-        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+        attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        mlp_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.

     ```python
     >>> from transformers import LlamaModel, LlamaConfig

@@ -137,6 +139,7 @@ class LlamaConfig(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
+        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size

@@ -161,6 +164,7 @@ class LlamaConfig(PretrainedConfig):
         self._rope_scaling_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.mlp_bias = mlp_bias

         super().__init__(
             pad_token_id=pad_token_id,
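Taken together, these configuration changes expose a new `mlp_bias` flag that defaults to `False`, so existing configs and checkpoints behave exactly as before. A minimal sketch of opting in, assuming a transformers build that includes this commit; the tiny sizes below are illustrative only:

```python
from transformers import LlamaConfig

# The new flag defaults to False, preserving the previous bias-free behaviour.
config = LlamaConfig(
    hidden_size=64,            # illustrative toy sizes, not a real checkpoint
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    mlp_bias=True,             # opt in to bias terms in the MLP projections
)
print(config.mlp_bias)  # True
```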
src/transformers/models/llama/modeling_llama.py
@@ -214,9 +214,9 @@ class LlamaMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
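With the projections now reading `bias=config.mlp_bias`, a model built from a config with `mlp_bias=True` gets trainable bias vectors in `gate_proj`, `up_proj`, and `down_proj`, while the default leaves them at `None` as before. A quick check, again assuming a build that contains this commit and using toy sizes:

```python
import torch
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
    vocab_size=1000,           # toy sizes for a quick local check
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=4,
    mlp_bias=True,
)
model = LlamaModel(config)
mlp = model.layers[0].mlp

# Each projection now carries a bias parameter; with mlp_bias=False they are None.
print(mlp.gate_proj.bias is not None)  # True
print(mlp.up_proj.bias is not None)    # True
print(mlp.down_proj.bias is not None)  # True

# The MLP is callable on hidden states of shape (batch, seq_len, hidden_size).
hidden = torch.randn(1, 5, config.hidden_size)
print(mlp(hidden).shape)  # torch.Size([1, 5, 64])
```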