Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2f40c728
Unverified
Commit
2f40c728
authored
Feb 11, 2022
by
Joao Gante
Committed by
GitHub
Feb 11, 2022
Browse files
TF MT5 embeddings resize (#15567)
* Fix TF MT5 vocab resize * more assertive testing
parent
8c03df10
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
1 deletion
+37
-1
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+5
-0
tests/test_modeling_tf_mt5.py
tests/test_modeling_tf_mt5.py
+18
-1
tests/test_modeling_tf_t5.py
tests/test_modeling_tf_t5.py
+14
-0
No files found.
src/transformers/modeling_tf_utils.py
View file @
2f40c728
...
@@ -1135,6 +1135,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
...
@@ -1135,6 +1135,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return
model_embeds
return
model_embeds
def
_get_word_embedding_weight
(
model
,
embedding_layer
):
def
_get_word_embedding_weight
(
model
,
embedding_layer
):
# If the variable holds the weights themselves, return them
if
isinstance
(
embedding_layer
,
tf
.
Tensor
):
return
embedding_layer
# Otherwise, try to get them from the layer's attributes
embeds
=
getattr
(
embedding_layer
,
"weight"
,
None
)
embeds
=
getattr
(
embedding_layer
,
"weight"
,
None
)
if
embeds
is
not
None
:
if
embeds
is
not
None
:
return
embeds
return
embeds
...
...
tests/test_modeling_tf_mt5.py
View file @
2f40c728
...
@@ -22,7 +22,24 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
...
@@ -22,7 +22,24 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
import
tensorflow
as
tf
from
transformers
import
AutoTokenizer
,
TFAutoModelForSeq2SeqLM
from
transformers
import
AutoTokenizer
,
T5Tokenizer
,
TFAutoModelForSeq2SeqLM
,
TFMT5ForConditionalGeneration
@
require_tf
class
TFMT5ModelTest
(
unittest
.
TestCase
):
# no mixin with common tests -> most cases are already covered in the TF T5
@slow
def test_resize_embeddings(self):
    """Resizing the token embeddings must change the embedding-matrix rows to the new vocab size."""
    mt5 = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    vocab_size_before = mt5.get_input_embeddings().weight.shape[0]
    # the vocab size is defined in the model config
    self.assertEqual(vocab_size_before, mt5.config.vocab_size)

    # Grow/shrink the vocabulary by loading a tokenizer and registering extra special tokens,
    # then resize the model's embeddings to match it.
    mt5_tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
    mt5_tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
    mt5._resize_token_embeddings(len(mt5_tokenizer))

    # the vocab size is now resized to the length of the tokenizer, which differs from the original size
    vocab_size_after = mt5.get_input_embeddings().weight.shape[0]
    self.assertEqual(vocab_size_after, len(mt5_tokenizer))
    self.assertNotEqual(vocab_size_after, vocab_size_before)
@
require_tf
@
require_tf
...
...
tests/test_modeling_tf_t5.py
View file @
2f40c728
...
@@ -314,6 +314,20 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -314,6 +314,20 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO: Fix head-masking according to PyTorch T5 model
# TODO: Fix head-masking according to PyTorch T5 model
pass
pass
@slow
def test_resize_embeddings(self):
    """Resizing the token embeddings must change the embedding-matrix rows to the new vocab size."""
    t5 = TFT5ForConditionalGeneration.from_pretrained("t5-small")
    vocab_size_before = t5.get_input_embeddings().weight.shape[0]
    # the vocab size is defined in the model config
    self.assertEqual(vocab_size_before, t5.config.vocab_size)

    # Grow/shrink the vocabulary by loading a tokenizer and registering extra special tokens,
    # then resize the model's embeddings to match it.
    t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
    t5_tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""})
    t5._resize_token_embeddings(len(t5_tokenizer))

    # the vocab size is now resized to the length of the tokenizer, which differs from the original size
    vocab_size_after = t5.get_input_embeddings().weight.shape[0]
    self.assertEqual(vocab_size_after, len(t5_tokenizer))
    self.assertNotEqual(vocab_size_after, vocab_size_before)
class
TFT5EncoderOnlyModelTester
:
class
TFT5EncoderOnlyModelTester
:
def
__init__
(
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment