chenpangpang/transformers · Commits · commit 00cbadb8 (unverified)
Authored Sep 10, 2022 by Joao Gante, committed by GitHub on Sep 10, 2022

RFC: Replace custom TF embeddings by Keras embeddings (#18939)

parent 855dcae8
Changes: 5 changed files, with 141 additions and 142 deletions (+141, -142)
src/transformers/modeling_tf_utils.py                +104    -3
src/transformers/models/bart/modeling_tf_bart.py      +22   -52
src/transformers/models/mbart/modeling_tf_mbart.py     +2    -1
tests/models/bart/test_modeling_tf_bart.py             +2   -65
tests/test_modeling_tf_common.py                      +11   -21
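This commit swaps the library's custom TFSharedEmbeddings/TFWrappedEmbeddings machinery for stock tf.keras.layers.Embedding layers. As a minimal, standalone sketch (illustrative sizes only, not code from the commit), these are the Embedding attributes the new code paths lean on:

    import tensorflow as tf

    # Illustrative sizes; the refactored models create such a layer as `self.shared`.
    shared = tf.keras.layers.Embedding(input_dim=50265, output_dim=768, name="model.shared")
    shared(tf.constant([[0]]))  # the first call builds the layer and creates its weight

    print(shared.input_dim, shared.output_dim)  # vocabulary size and hidden size
    print(shared.embeddings.shape)              # (50265, 768) weight matrix, created in build()
    print(shared.embeddings.name)               # scoped weight name, typically ending in "embeddings:0"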
src/transformers/modeling_tf_utils.py
...
@@ -887,6 +887,12 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
                     # If not, make the value to None
                     saved_weight_value = saved_weights.get(symbolic_weight_name, None)

+                    # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's
+                    # `model.shared/embeddings:0` are stored as `model.shared/weights:0`)
+                    if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
+                        symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
+                        saved_weight_value = saved_weights.get(symbolic_weight_name, None)
+
                     # Add the updated name to the final list for computing missing/unexpected values
                     symbolic_weights_names.add(symbolic_weight_name)
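The added block is a pure name rewrite: when a weight named `.../embeddings:0` is missing from the checkpoint, loading is retried under the legacy `.../weight:0` name. The string handling in isolation (standalone sketch, not the library function):

    # "embeddings:0" is 12 characters long, so [:-12] keeps the "model.shared/" prefix.
    def legacy_embedding_name(symbolic_weight_name: str) -> str:
        assert symbolic_weight_name.endswith("embeddings:0")
        return symbolic_weight_name[:-12] + "weight:0"

    print(legacy_embedding_name("model.shared/embeddings:0"))  # model.shared/weight:0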
...
@@ -1700,7 +1706,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         """
         return None

-    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None
+    ) -> Union[tf.keras.layers.Embedding, tf.Variable]:
         """
         Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
...
@@ -1710,11 +1718,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             new_num_tokens (`int`, *optional*):
                 The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                 vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
-                returns a pointer to the input tokens `tf.Variable` module of the model without doing anything.
+                returns a pointer to the input tokens without doing anything.

         Return:
-            `tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
+            `tf.Variable` or `tf.keras.layers.Embedding`: Pointer to the input tokens of the model.
         """
+        # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor
+
+        # Run the new code path if the model has a keras embeddings layer
+        if isinstance(self.get_input_embeddings(), tf.keras.layers.Embedding):
+            return self._v2_resized_token_embeddings(new_num_tokens)
+
         if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
             return self._get_word_embedding_weight(self.get_input_embeddings())
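The public method keeps its behaviour but now dispatches on the embedding type, so the return value is a tf.keras.layers.Embedding for refactored models and a tf.Variable otherwise. A usage sketch (the checkpoint name is only an example and needs its TF weights to be available):

    from transformers import TFAutoModelForSeq2SeqLM

    model = TFAutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

    # Grow the vocabulary by 8 rows; the config's vocab_size is updated as a side effect.
    new_embeddings = model.resize_token_embeddings(model.config.vocab_size + 8)
    print(type(new_embeddings), model.config.vocab_size)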
...
@@ -1725,7 +1739,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

         return model_embeds

+    def _v2_resized_token_embeddings(self, new_num_tokens: Optional[int] = None) -> tf.keras.layers.Embedding:
+        """
+        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
+
+        Arguments:
+            new_num_tokens (`int`, *optional*):
+                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
+                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+                returns a pointer to the input tokens without doing anything.
+
+        Return:
+            `tf.keras.layers.Embedding`: Pointer to the input tokens of the model.
+        """
+        if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
+            return self.get_input_embeddings()
+
+        model_embeds = self._v2_resize_token_embeddings(new_num_tokens)
+
+        # Update base model and current model config
+        self.config.vocab_size = new_num_tokens
+
+        return model_embeds
+
     def _get_word_embedding_weight(model, embedding_layer):
+        # TODO (joao): flagged for delection due to embeddings refactor
+
         # If the variable holds the weights themselves, return them
         if isinstance(embedding_layer, tf.Tensor):
             return embedding_layer
...
@@ -1755,6 +1794,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             return None

     def _resize_token_embeddings(self, new_num_tokens):
+        # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor
         old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
...
@@ -1776,6 +1816,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

         return self.get_input_embeddings()

+    def _v2_resize_token_embeddings(self, new_num_tokens):
+        old_embeddings = self.get_input_embeddings()
+        new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens)
+        self.set_input_embeddings(new_embeddings)
+
+        # If word embeddings are not tied, make sure that lm head bias is resized as well
+        if self.get_bias() is not None:
+            old_lm_head_bias = self.get_bias()
+            new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
+            self.set_bias(new_lm_head_bias)
+
+        # If word embeddings are not tied, make sure that lm head decoder is resized as well.
+        tied_weights = self.get_input_embeddings() == self.get_output_embeddings()
+        if self.get_output_embeddings() is not None and not tied_weights:
+            old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
+            # TODO (joao): this one probably needs a v2 version with other models
+            new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
+            self.set_output_embeddings(new_lm_head_decoder)
+
+        return self.get_input_embeddings()
+
     def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
         """
         Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
...
@@ -1885,6 +1946,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
             `None`
         """
+        # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor
         old_embedding_dim = shape_list(old_embeddings)[1]
         init_range = getattr(self.config, "initializer_range", 0.02)
         embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
...
@@ -1900,6 +1962,42 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

         return new_embeddings

+    def _v2_get_resized_embeddings(
+        self, old_embeddings: tf.keras.layers.Embedding, new_num_tokens: int
+    ) -> tf.keras.layers.Embedding:
+        """
+        Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized
+        vectors at the end. Reducing the size will remove vectors from the end.
+
+        Args:
+            old_embeddings (`tf.keras.layers.Embedding`):
+                Old embeddings to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the embedding matrix.
+
+        Return:
+            `tf.keras.layers.Embedding`: Resized Embedding layer.
+        """
+        # Get a new (initialized) embeddings layer
+        init_range = getattr(self.config, "initializer_range", 0.02)
+        new_embeddings = tf.keras.layers.Embedding(
+            input_dim=new_num_tokens,
+            output_dim=old_embeddings.output_dim,
+            embeddings_initializer=get_initializer(init_range),
+            name=old_embeddings.embeddings.name[:-13],  # exact same scoped name except "/embeddings:0"
+        )
+        new_embeddings(tf.constant([[0]]))
+
+        # Copy the old embeddings to the new embeddings
+        if old_embeddings.input_dim >= new_num_tokens:
+            init_embeddings = old_embeddings.embeddings[:new_num_tokens]
+        else:
+            init_embeddings = tf.concat(
+                [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0
+            )
+        new_embeddings.embeddings.assign(init_embeddings)
+        return new_embeddings
+
     def prune_heads(self, heads_to_prune):
         """
         Prunes heads of the base model.
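The copy step above either truncates the old matrix or concatenates the freshly initialized tail rows onto it. The same grow/shrink behaviour with a plain Keras Embedding (standalone sketch; the helper name is made up and this is not the library API):

    import tensorflow as tf

    def resize_embedding(old: tf.keras.layers.Embedding, new_num_tokens: int) -> tf.keras.layers.Embedding:
        # Keep the old rows; any extra rows keep their fresh initialization.
        new = tf.keras.layers.Embedding(new_num_tokens, old.output_dim)
        new(tf.constant([[0]]))  # force build so `new.embeddings` exists
        if old.input_dim >= new_num_tokens:
            init = old.embeddings[:new_num_tokens]
        else:
            init = tf.concat([old.embeddings, new.embeddings[old.input_dim:]], axis=0)
        new.embeddings.assign(init)
        return new

    old = tf.keras.layers.Embedding(10, 4)
    old(tf.constant([[0]]))
    bigger = resize_embedding(old, 12)
    print(bigger.embeddings.shape)                                         # (12, 4)
    print(bool(tf.reduce_all(bigger.embeddings[:10] == old.embeddings)))   # True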
...
@@ -2632,6 +2730,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
         kwargs:
             Additional keyword arguments passed along to the `__init__` of `tf.keras.layers.Layer`.
     """
+    # TODO (joao): flagged for delection due to embeddings refactor

     def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
         super().__init__(**kwargs)
...
@@ -2848,6 +2947,8 @@ class TFWrappedEmbeddings:
     saving/storing the correct weights
     """

+    # TODO (joao): flagged for delection due to embeddings refactor
+
     def __init__(self, layer, abs_scope_name=None):
         self._layer = layer
         self._abs_scope_name = abs_scope_name
...
src/transformers/models/bart/modeling_tf_bart.py
...
@@ -35,8 +35,6 @@ from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFModelInputType,
     TFPreTrainedModel,
-    TFSharedEmbeddings,
-    TFWrappedEmbeddings,
     keras_serializable,
     unpack_inputs,
 )
...
@@ -113,7 +111,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
     return (one_cst - expanded_mask) * LARGE_NEGATIVE


-class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
+class TFBartLearnedPositionalEmbedding(tf.keras.layers.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size.
     """
...
@@ -136,7 +134,8 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
         position_ids = tf.range(seq_len, delta=1, name="range")

         position_ids += past_key_values_length
-        return super().call(position_ids + self.offset)
+        offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
+        return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype))


 class TFBartAttention(tf.keras.layers.Layer):
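TensorFlow does not auto-promote integer dtypes, so adding an int32 constant to int64 position ids fails; deriving the constant's dtype from position_ids, as the new lines do, avoids that. A small standalone illustration (not from the commit):

    import tensorflow as tf

    position_ids = tf.constant([0, 1, 2], dtype=tf.int64)
    offset = 2

    try:
        position_ids + tf.constant(offset)  # tf.constant(2) defaults to int32
    except (tf.errors.InvalidArgumentError, TypeError) as e:
        print("dtype mismatch:", type(e).__name__)

    offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
    print(position_ids + tf.constant(offset, dtype=offset_dtype))  # [2 3 4]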
...
@@ -667,7 +666,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
     config: BartConfig
     """

-    def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
+    def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
         self.dropout = tf.keras.layers.Dropout(config.dropout)
...
@@ -685,12 +684,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
         self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
         self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")

-    def get_embed_tokens(self):
-        return self.embed_tokens
-
-    def set_embed_tokens(self, embed_tokens):
-        self.embed_tokens = embed_tokens
-
     @unpack_inputs
     def call(
         self,
...
@@ -750,6 +743,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            with tf.name_scope(self.embed_tokens.name + "/"):
+                inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
...
@@ -820,7 +814,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
     embed_tokens: output embedding
     """

-    def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
+    def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
         self.padding_idx = config.pad_token_id
...
@@ -837,12 +831,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
         self.dropout = tf.keras.layers.Dropout(config.dropout)

-    def get_embed_tokens(self):
-        return self.embed_tokens
-
-    def set_embed_tokens(self, embed_tokens):
-        self.embed_tokens = embed_tokens
-
     @unpack_inputs
     def call(
         self,
...
@@ -943,6 +931,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
         positions = self.embed_positions(input_shape, position_ids=position_ids)

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            with tf.name_scope(self.embed_tokens.name + "/"):
+                inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         hidden_states = inputs_embeds
...
@@ -1038,36 +1027,19 @@ class TFBartMainLayer(tf.keras.layers.Layer):
     def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
-        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
-
-        # set tf scope correctly
-        if load_weight_prefix is None:
-            load_weight_prefix = "model.shared"
-
-        with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name:
-            pass
-
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        embed_tokens.vocab_size = self.shared.vocab_size
-        embed_tokens.hidden_size = self.shared.hidden_size
-
-        self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
-        self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
+        load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
+        self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name=load_weight_prefix)
+
+        self.encoder = TFBartEncoder(config, self.shared, name="encoder")
+        self.decoder = TFBartDecoder(config, self.shared, name="decoder")

     def get_input_embeddings(self):
         return self.shared

     def set_input_embeddings(self, new_embeddings):
-        self.shared.weight = new_embeddings
-        self.shared.vocab_size = self.shared.weight.shape[0]
-        # retrieve correct absolute scope for embed token wrapper
-        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
-            pass
-        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
-        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
-        self.encoder.set_embed_tokens(embed_tokens)
-        self.decoder.set_embed_tokens(embed_tokens)
+        self.shared = new_embeddings
+        self.encoder.embed_tokens = self.shared
+        self.decoder.embed_tokens = self.shared

     @unpack_inputs
     def call(
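Because encoder and decoder now hold a reference to the same Embedding layer, set_input_embeddings reduces to attribute reassignment; no wrapper re-scoping is needed. A quick reference-sharing sanity check (toy layer, not the library classes):

    import tensorflow as tf

    shared = tf.keras.layers.Embedding(10, 4, name="model.shared")
    encoder_embed_tokens = shared
    decoder_embed_tokens = shared

    shared(tf.constant([[0]]))  # build the layer
    shared.embeddings.assign(tf.zeros((10, 4)))

    print(encoder_embed_tokens is decoder_embed_tokens)           # True: same layer object
    print(float(tf.reduce_sum(decoder_embed_tokens.embeddings)))  # 0.0: one assign updates both views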
...
@@ -1273,11 +1245,7 @@ class BiasLayer(tf.keras.layers.Layer):
     BART_START_DOCSTRING,
 )
 class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
-    _keys_to_ignore_on_load_unexpected = [
-        r"model.encoder.embed_tokens.weight",
-        r"model.decoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = [r"final_logits_bias"]
     _requires_load_weight_prefix = True

     def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs):
...
@@ -1303,10 +1271,10 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
         self.set_input_embeddings(value)

     def get_bias(self):
-        return {"final_logits_bias": self.final_logits_bias}
+        return {"final_logits_bias": self.bias_layer.bias}

     def set_bias(self, value):
-        self.final_logits_bias = value["final_logits_bias"]
+        self.bias_layer.bias = value["final_logits_bias"]

     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
...
@@ -1374,7 +1342,9 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
             return_dict=return_dict,
             training=training,
         )
-        lm_logits = self.model.shared(outputs[0], mode="linear")
+        # TODO (joao): the line below is for models with tied embeddings. The previous TFBart had tied embeddings.
+        # The PT Bart does not have tied embeddings. Untie the weights while keeping loading retrocompatibility.
+        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
         lm_logits = self.bias_layer(lm_logits)
         masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
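With the custom layer's "linear" mode gone, the LM head reuses the input embedding matrix directly: logits are a matmul against the transposed embedding weights, followed by the additive bias layer. A standalone shape check with toy sizes and plain tensors (not the model's actual layers):

    import tensorflow as tf

    batch, seq_len, d_model, vocab_size = 2, 5, 8, 11
    hidden_states = tf.random.normal((batch, seq_len, d_model))
    embedding_matrix = tf.random.normal((vocab_size, d_model))  # rows are token embeddings
    final_logits_bias = tf.zeros((1, vocab_size))

    lm_logits = tf.matmul(hidden_states, embedding_matrix, transpose_b=True) + final_logits_bias
    print(lm_logits.shape)  # (2, 5, 11)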
...
src/transformers/models/mbart/modeling_tf_mbart.py
...
@@ -137,7 +137,8 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
         position_ids = tf.range(seq_len, delta=1, name="range")

         position_ids += past_key_values_length
-        return super().call(position_ids + self.offset)
+        offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
+        return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype))


 # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart
...
tests/models/bart/test_modeling_tf_bart.py
...
@@ -230,69 +230,6 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
         name = model.get_bias()
         assert name is None

-    def test_resize_token_embeddings(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        def _get_word_embedding_weight(model, embedding_layer):
-            if hasattr(embedding_layer, "weight"):
-                return embedding_layer.weight
-            else:
-                # Here we build the word embeddings weights if not exists.
-                # And then we retry to get the attribute once built.
-                model(model.dummy_inputs)
-                if hasattr(embedding_layer, "weight"):
-                    return embedding_layer.weight
-                else:
-                    return None
-
-        for model_class in self.all_model_classes:
-            for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
-                # build the embeddings
-                model = model_class(config=config)
-                old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
-                old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
-                old_final_logits_bias = model.get_bias()
-
-                # reshape the embeddings
-                model.resize_token_embeddings(size)
-                new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
-                new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
-                new_final_logits_bias = model.get_bias()
-
-                # check that the resized embeddings size matches the desired size.
-                assert_size = size if size is not None else config.vocab_size
-                self.assertEqual(new_input_embeddings.shape[0], assert_size)
-
-                # check that weights remain the same after resizing
-                models_equal = True
-                for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
-                    if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
-                        models_equal = False
-                self.assertTrue(models_equal)
-
-                if old_output_embeddings is not None and new_output_embeddings is not None:
-                    self.assertEqual(new_output_embeddings.shape[0], assert_size)
-
-                    models_equal = True
-                    for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
-                        if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
-                            models_equal = False
-                    self.assertTrue(models_equal)
-
-                if old_final_logits_bias is not None and new_final_logits_bias is not None:
-                    old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
-                    new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
-                    self.assertEqual(new_final_logits_bias.shape[0], 1)
-                    self.assertEqual(new_final_logits_bias.shape[1], assert_size)
-
-                    models_equal = True
-                    for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
-                        for p1, p2 in zip(old, new):
-                            if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
-                                models_equal = False
-                    self.assertTrue(models_equal)
-
     @tooslow
     def test_saved_model_creation(self):
         pass
...
@@ -635,7 +572,7 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
     def test_xsum_1_1_generation(self):
         model = self.xsum_1_1_model
-        assert model.model.decoder.embed_tokens._layer == model.model.shared
+        assert model.model.decoder.embed_tokens == model.model.shared
         ARTICLE = (
             "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
             " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
...
@@ -685,7 +622,7 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
     def test_xsum_1_1_xla_generation(self):
         # same test as above, but with `no_repeat_ngram_size=0` (not compatible with XLA) and XLA comparison enabled
         model = self.xsum_1_1_model
-        assert model.model.decoder.embed_tokens._layer == model.model.shared
+        assert model.model.decoder.embed_tokens == model.model.shared
         ARTICLE = (
             "The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
             " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
...
tests/test_modeling_tf_common.py
...
@@ -1144,30 +1144,20 @@ class TFModelTesterMixin:
         self.assert_outputs_same(output_for_dict_input, output_for_kw_input)

     def test_resize_token_embeddings(self):
+        # TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
+        # tf.keras.layers.Embedding
         if not self.test_resize_embeddings:
             return
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         def _get_word_embedding_weight(model, embedding_layer):
-            embeds = getattr(embedding_layer, "weight", None)
-            if embeds is not None:
-                return embeds
-
-            embeds = getattr(embedding_layer, "decoder", None)
-            if embeds is not None:
-                return embeds
-
-            model(model.dummy_inputs)
-
-            embeds = getattr(embedding_layer, "weight", None)
-            if embeds is not None:
-                return embeds
-
-            embeds = getattr(embedding_layer, "decoder", None)
-            if embeds is not None:
-                return embeds
-
-            return None
+            if isinstance(embedding_layer, tf.keras.layers.Embedding):
+                # builds the embeddings layer
+                model(model.dummy_inputs)
+                return embedding_layer.embeddings
+            else:
+                return model._get_word_embedding_weight(embedding_layer)

         for model_class in self.all_model_classes:
             for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
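The rewritten helper calls model(model.dummy_inputs) before reading embedding_layer.embeddings because a Keras Embedding only creates that variable in build(), i.e. on its first call. A minimal illustration (standalone, not the test itself):

    import tensorflow as tf

    layer = tf.keras.layers.Embedding(10, 4)
    print(layer.built, layer.weights)            # False, [] -- no weight variable yet

    layer(tf.constant([[1, 2, 3]]))              # first call triggers build()
    print(layer.built, layer.embeddings.shape)   # True, (10, 4)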
...
@@ -1195,10 +1185,10 @@ class TFModelTesterMixin:

                 if old_bias is not None and new_bias is not None:
                     for old_weight, new_weight in zip(old_bias.values(), new_bias.values()):
-                        self.assertEqual(new_weight.shape[0], assert_size)
+                        self.assertEqual(new_weight.shape[-1], assert_size)

                         models_equal = True
-                        for p1, p2 in zip(old_weight.value(), new_weight.value()):
+                        for p1, p2 in zip(tf.squeeze(old_weight), tf.squeeze(new_weight)):
                             if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
                                 models_equal = False
                         self.assertTrue(models_equal)
...