chenpangpang / transformers · Commit 9ed80b00 (Unverified)

TF: TFBart embedding initialization (#19460)

* correct embedding init

Authored Oct 11, 2022 by Joao Gante; committed by GitHub on Oct 11, 2022.
Parent: b651efe5

Changes: 2 changed files with 19 additions and 3 deletions (+19 / -3)

- src/transformers/modeling_tf_utils.py (+13 / -2)
- src/transformers/models/bart/modeling_tf_bart.py (+6 / -1)
src/transformers/modeling_tf_utils.py

```diff
@@ -2059,12 +2059,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         Return:
             `tf.keras.layers.Embedding`: Resized Embedding layer.
         """
+        # Get the initialization range for the embeddings
+        init_range = 0.02  # default value
+        potential_initialization_variable_names = [
+            "initializer_range",  # most common
+            "initializer_factor",  # e.g. T5
+            "init_std",  # e.g BART
+        ]
+        for var_name in potential_initialization_variable_names:
+            if hasattr(self.config, var_name):
+                init_range = getattr(self.config, var_name)
+
         # Get a new (initialized) embeddings layer
-        init_range = getattr(self.config, "initializer_range", 0.02)
         new_embeddings = tf.keras.layers.Embedding(
             input_dim=new_num_tokens,
             output_dim=old_embeddings.output_dim,
-            embeddings_initializer=get_initializer(init_range),
+            embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=init_range),
             name=old_embeddings.embeddings.name[:-13],  # exact same scoped name except "/embeddings:0"
         )
         new_embeddings(tf.constant([[0]]))
```
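The new lookup scans a fixed list of config attribute names instead of hard-coding `initializer_range`, and, because the loop has no `break`, the last attribute found in the list wins. Here is a minimal sketch of that behavior in isolation; `DummyConfig` and its `init_std` value are illustrative stand-ins for `self.config`, not part of the commit:

```python
import tensorflow as tf


class DummyConfig:
    """Stand-in for a model config; only `init_std` is defined, BART-style."""

    init_std = 0.01  # arbitrary example value


config = DummyConfig()

init_range = 0.02  # default value, as in the diff above
for var_name in ["initializer_range", "initializer_factor", "init_std"]:
    if hasattr(config, var_name):
        init_range = getattr(config, var_name)  # no break: the last match wins

print(init_range)  # 0.01, taken from init_std rather than the 0.02 default
initializer = tf.keras.initializers.TruncatedNormal(stddev=init_range)
```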
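For context, this resize path runs when a user calls the public `resize_token_embeddings` API on a TF model, for example after adding tokens to a tokenizer. A usage sketch (the standard `facebook/bart-base` checkpoint and the added token are example choices, not from the commit):

```python
from transformers import BartTokenizer, TFBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

tokenizer.add_tokens(["<new_token>"])
# New embedding rows are created by the code above, now initialized with the
# range discovered from the config (config.init_std for BART) instead of a
# hard-coded initializer_range.
model.resize_token_embeddings(len(tokenizer))
```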
src/transformers/models/bart/modeling_tf_bart.py

```diff
@@ -1053,7 +1053,12 @@ class TFBartMainLayer(tf.keras.layers.Layer):
     def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
-        self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name="model.shared")
+        self.shared = tf.keras.layers.Embedding(
+            input_dim=config.vocab_size,
+            output_dim=config.d_model,
+            embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=self.config.init_std),
+            name="model.shared",
+        )
         # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
         self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
```
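The BART change matters because `tf.keras.layers.Embedding` falls back to its default `"uniform"` initializer (random uniform on [-0.05, 0.05]) when no `embeddings_initializer` is passed, so the old one-liner silently ignored `config.init_std`. A small sketch contrasting the two constructions; the dimensions and stddev are arbitrary demo values:

```python
import tensorflow as tf

vocab_size, d_model, init_std = 1000, 64, 0.02  # arbitrary demo values

# Before this commit: no initializer passed, so Keras uses its default
# "uniform" initializer, i.e. RandomUniform on [-0.05, 0.05].
old = tf.keras.layers.Embedding(vocab_size, d_model)
old(tf.constant([[0]]))  # build the layer, as the diff above does

# After this commit: weights are drawn from a truncated normal with the
# stddev requested by the config (config.init_std for BART).
new = tf.keras.layers.Embedding(
    input_dim=vocab_size,
    output_dim=d_model,
    embeddings_initializer=tf.keras.initializers.TruncatedNormal(stddev=init_std),
)
new(tf.constant([[0]]))

print(float(tf.math.reduce_std(old.embeddings)))  # ~0.029: uniform width / sqrt(12)
print(float(tf.math.reduce_std(new.embeddings)))  # ~0.018: truncation shrinks std to ~0.88 * stddev
```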