chenpangpang / transformers · Commits

Commit 8a6928e2 (unverified)
TF: correct TFBart embeddings weights name when load_weight_prefix is passed (#18993)

Authored Sep 12, 2022 by Joao Gante; committed via GitHub on Sep 12, 2022
Parent: c126a239
Showing 1 changed file with 22 additions and 4 deletions.

src/transformers/models/bart/modeling_tf_bart.py  (+22, -4)  [view file @ 8a6928e2]
@@ -16,6 +16,7 @@
 import random
+from contextlib import nullcontext
 from typing import Optional, Tuple, Union

 import numpy as np
@@ -748,7 +749,15 @@ class TFBartEncoder(tf.keras.layers.Layer):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            with tf.name_scope(self.embed_tokens.name + "/"):
+            # if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
+            # scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
+            # is used with a name ending in `/`, that name replaces the current name scope.
+            # (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
+            if hasattr(self.embed_tokens, "load_weight_prefix"):
+                context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
+            else:
+                context_manager = nullcontext()
+            with context_manager:
                 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
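For context on the comment above: a name passed to `tf.name_scope` that ends in `/` is treated as an absolute scope and replaces the current one instead of nesting under it. A minimal sketch, not part of the diff (layer and scope names are illustrative, and the exact weight name can vary across TF/Keras versions):

```python
import tensorflow as tf

# Build an embedding layer inside an absolute name scope: the trailing "/"
# makes "model.shared" replace the current scope rather than nest under it.
emb = tf.keras.layers.Embedding(input_dim=10, output_dim=4, name="shared")

with tf.name_scope("model.shared/"):
    emb(tf.constant([[1, 2, 3]]))  # first call builds the layer's weights

# The weight is registered under the absolute scope plus the layer's own scope,
# e.g. "model.shared/shared/embeddings:0".
print(emb.weights[0].name)
```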
@@ -936,7 +945,15 @@ class TFBartDecoder(tf.keras.layers.Layer):
         positions = self.embed_positions(input_shape, position_ids=position_ids)

         if inputs_embeds is None:
-            with tf.name_scope(self.embed_tokens.name + "/"):
+            # if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
+            # scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
+            # is used with a name ending in `/`, that name replaces the current name scope.
+            # (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
+            if hasattr(self.embed_tokens, "load_weight_prefix"):
+                context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
+            else:
+                context_manager = nullcontext()
+            with context_manager:
                 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         hidden_states = inputs_embeds
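The decoder hunk applies the same fallback pattern as the encoder. Factored out as a standalone sketch (the helper name is hypothetical, not something the diff introduces):

```python
from contextlib import nullcontext

import tensorflow as tf


def embedding_name_scope(embed_tokens):
    # Enter the prefix scope when the embedding layer carries a
    # `load_weight_prefix` attribute; otherwise return a no-op context manager.
    if hasattr(embed_tokens, "load_weight_prefix"):
        return tf.name_scope(embed_tokens.load_weight_prefix + "/")  # trailing "/" = absolute scope
    return nullcontext()


# Usage, mirroring the diff:
#     with embedding_name_scope(self.embed_tokens):
#         inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
```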
@@ -1032,8 +1049,9 @@ class TFBartMainLayer(tf.keras.layers.Layer):
     def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
-        load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
-        self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name=load_weight_prefix)
+        self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name="model.shared")
+        # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
+        self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix

         self.encoder = TFBartEncoder(config, self.shared, name="encoder")
         self.decoder = TFBartDecoder(config, self.shared, name="decoder")
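Net effect of the TFBartMainLayer change: previously a caller-supplied `load_weight_prefix` overwrote the Keras name of the shared embedding, so its weights were registered under an unexpected name whenever a prefix was passed; now the Keras name is fixed at "model.shared" and the prefix travels on a separate attribute that the encoder and decoder consult (see the hunks above). A minimal sketch of the resulting state, with illustrative sizes and a hypothetical prefix string:

```python
import tensorflow as tf

# The layer's Keras name no longer depends on the caller:
shared = tf.keras.layers.Embedding(50265, 1024, name="model.shared")

# Default case: the expected name scope equals the layer name.
shared.load_weight_prefix = "model.shared"

# Composite-model case: a caller embeds TFBartMainLayer in a larger model and
# passes its own prefix; only the registered weight name changes, not the layer name.
shared.load_weight_prefix = "parent_model/model.shared"  # hypothetical prefix
```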