Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b75255cd
Unverified
Commit
b75255cd
authored
Nov 30, 2022
by
Younes Belkada
Committed by
GitHub
Nov 30, 2022
Browse files
[OPT/Galactica] Load large `galactica` models (#20390)
* fix `opt` bias * revert unneeded assignment
parent
293991d4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
5 deletions
+19
-5
src/transformers/models/opt/configuration_opt.py
src/transformers/models/opt/configuration_opt.py
+9
-0
src/transformers/models/opt/modeling_opt.py
src/transformers/models/opt/modeling_opt.py
+10
-5
No files found.
src/transformers/models/opt/configuration_opt.py
View file @
b75255cd
...
...
@@ -74,6 +74,10 @@ class OPTConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
enable_bias (`bool`, *optional*, defaults to `True`):
Whether or not the linear layers in the attention blocks should use the bias term.
layer_norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether or not the layer norms should have learnable parameters.
Example:
...
...
@@ -112,6 +116,8 @@ class OPTConfig(PretrainedConfig):
pad_token_id
=
1
,
bos_token_id
=
2
,
eos_token_id
=
2
,
enable_bias
=
True
,
layer_norm_elementwise_affine
=
True
,
**
kwargs
):
super
().
__init__
(
...
...
@@ -134,6 +140,9 @@ class OPTConfig(PretrainedConfig):
self
.
layerdrop
=
layerdrop
self
.
use_cache
=
use_cache
self
.
do_layer_norm_before
=
do_layer_norm_before
# We keep these variables at `True` for backward compatibility.
self
.
enable_bias
=
enable_bias
self
.
layer_norm_elementwise_affine
=
layer_norm_elementwise_affine
# Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
...
...
src/transformers/models/opt/modeling_opt.py
View file @
b75255cd
...
...
@@ -279,15 +279,18 @@ class OPTDecoderLayer(nn.Module):
num_heads
=
config
.
num_attention_heads
,
dropout
=
config
.
attention_dropout
,
is_decoder
=
True
,
bias
=
config
.
enable_bias
,
)
self
.
do_layer_norm_before
=
config
.
do_layer_norm_before
self
.
dropout
=
config
.
dropout
self
.
activation_fn
=
ACT2FN
[
config
.
activation_function
]
self
.
self_attn_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
self
.
fc1
=
nn
.
Linear
(
self
.
embed_dim
,
config
.
ffn_dim
)
self
.
fc2
=
nn
.
Linear
(
config
.
ffn_dim
,
self
.
embed_dim
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
)
self
.
self_attn_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
,
elementwise_affine
=
config
.
layer_norm_elementwise_affine
)
self
.
fc1
=
nn
.
Linear
(
self
.
embed_dim
,
config
.
ffn_dim
,
bias
=
config
.
enable_bias
)
self
.
fc2
=
nn
.
Linear
(
config
.
ffn_dim
,
self
.
embed_dim
,
bias
=
config
.
enable_bias
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
,
elementwise_affine
=
config
.
layer_norm_elementwise_affine
)
def
forward
(
self
,
...
...
@@ -507,7 +510,9 @@ class OPTDecoder(OPTPreTrainedModel):
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if
config
.
do_layer_norm_before
and
not
config
.
_remove_final_layer_norm
:
self
.
final_layer_norm
=
nn
.
LayerNorm
(
config
.
hidden_size
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
elementwise_affine
=
config
.
layer_norm_elementwise_affine
)
else
:
self
.
final_layer_norm
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment