Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6d211429
Unverified
Commit
6d211429
authored
May 17, 2022
by
SaulLu
Committed by
GitHub
May 17, 2022
Browse files
fix retribert's `test_torch_encode_plus_sent_to_model` (#17231)
parent
ec7f8af1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
2 deletions
+46
-2
tests/models/retribert/test_tokenization_retribert.py
tests/models/retribert/test_tokenization_retribert.py
+46
-2
No files found.
tests/models/retribert/test_tokenization_retribert.py
View file @
6d211429
...
@@ -27,9 +27,9 @@ from transformers.models.bert.tokenization_bert import (
...
@@ -27,9 +27,9 @@ from transformers.models.bert.tokenization_bert import (
_is_punctuation
,
_is_punctuation
,
_is_whitespace
,
_is_whitespace
,
)
)
from
transformers.testing_utils
import
require_tokenizers
,
slow
from
transformers.testing_utils
import
require_tokenizers
,
require_torch
,
slow
from
...test_tokenization_common
import
TokenizerTesterMixin
,
filter_non_english
from
...test_tokenization_common
import
TokenizerTesterMixin
,
filter_non_english
,
merge_model_tokenizer_mappings
# Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
# Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
...
@@ -338,3 +338,47 @@ class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
...
@@ -338,3 +338,47 @@ class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
]
]
self
.
assertListEqual
(
tokens_without_spe_char_p
,
expected_tokens
)
self
.
assertListEqual
(
tokens_without_spe_char_p
,
expected_tokens
)
self
.
assertListEqual
(
tokens_without_spe_char_r
,
expected_tokens
)
self
.
assertListEqual
(
tokens_without_spe_char_r
,
expected_tokens
)
# RetriBertModel doesn't define `get_input_embeddings` and its forward method doesn't take only the output of the tokenizer as input
@require_torch
@slow
def test_torch_encode_plus_sent_to_model(self):
    """Check that tokenizer output can be passed to the matching model.

    This differs from the common tokenizer test in two places: the
    embedding-size check reads the embeddings through
    ``model.bert_query.get_input_embeddings()``, and the encoded inputs are
    passed to ``model.embed_questions`` instead of calling the model directly.

    Tokenizers that cannot be exercised (no entry in the model/tokenizer
    mapping, an encoder-decoder config, or a config without a pad token id)
    are reported as skipped sub-tests rather than silently ending the whole
    test, so the remaining tokenizers are still checked.
    """
    import torch

    from transformers import MODEL_MAPPING, TOKENIZER_MAPPING

    model_tokenizer_mapping = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)

    tokenizers = self.get_tokenizers(do_lower_case=False)
    for tokenizer in tokenizers:
        tokenizer_class = tokenizer.__class__
        tokenizer_name = tokenizer_class.__name__
        with self.subTest(tokenizer_name):
            # A skip raised inside `subTest` is recorded for this sub-test only;
            # the loop then moves on to the next tokenizer.
            if tokenizer_class not in model_tokenizer_mapping:
                self.skipTest(f"{tokenizer_name} is not in the model/tokenizer mapping")

            config_class, model_class = model_tokenizer_mapping[tokenizer_class]
            config = config_class()

            if config.is_encoder_decoder or config.pad_token_id is None:
                self.skipTest(
                    f"{config_class.__name__} is an encoder-decoder config or defines no pad token id"
                )

            model = model_class(config)

            # The following check is different from the common test's one:
            # the embeddings are reached through `bert_query`.
            self.assertGreaterEqual(
                model.bert_query.get_input_embeddings().weight.shape[0],
                len(tokenizer),
            )

            # Build a sequence out of the first ten vocabulary entries.
            first_ten_tokens = list(tokenizer.get_vocab())[:10]
            sequence = " ".join(first_ten_tokens)
            encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")

            # Ensure that the BatchEncoding.to() method works.
            encoded_sequence.to(model.device)

            batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")

            # This should not fail.
            with torch.no_grad():  # saves some time
                # The following lines are different from the common test's ones.
                model.embed_questions(**encoded_sequence)
                model.embed_questions(**batch_encoded_sequence)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment