Commit 6d211429 (unverified)
Authored May 17, 2022 by SaulLu; committed by GitHub on May 17, 2022
Parent: ec7f8af1

fix retribert's `test_torch_encode_plus_sent_to_model` (#17231)
Showing 1 changed file with 46 additions and 2 deletions:

tests/models/retribert/test_tokenization_retribert.py (+46, -2)
@@ -27,9 +27,9 @@ from transformers.models.bert.tokenization_bert import (
     _is_punctuation,
     _is_whitespace,
 )
-from transformers.testing_utils import require_tokenizers, slow
+from transformers.testing_utils import require_tokenizers, require_torch, slow
-from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
+from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings

 # Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
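For context, a minimal sketch of what the `merge_model_tokenizer_mappings` helper imported above plausibly does; the real implementation lives in `tests/test_tokenization_common.py`, and this reconstruction is an assumption. It inverts the config-class-keyed `MODEL_MAPPING` and `TOKENIZER_MAPPING` into a single dict keyed by tokenizer class, so a test can look up the matching config and model classes for whichever tokenizer it is running with:

from collections import OrderedDict

def merge_model_tokenizer_mappings(model_mapping, tokenizer_mapping):
    # Assumed shapes: both mappings are keyed by config class, and
    # tokenizer_mapping values are (slow_tokenizer, fast_tokenizer) pairs.
    merged = OrderedDict()
    for config_class, model_class in model_mapping.items():
        if config_class not in tokenizer_mapping:
            continue
        slow_tokenizer, fast_tokenizer = tokenizer_mapping[config_class]
        if slow_tokenizer is not None:
            merged[slow_tokenizer] = (config_class, model_class)
        if fast_tokenizer is not None:
            merged[fast_tokenizer] = (config_class, model_class)
    return merged

This shape matches how the new test below consumes it: `config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]`.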
@@ -338,3 +338,47 @@ class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         ]
         self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
         self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
+
+    # RetriBertModel doesn't define `get_input_embeddings` and its forward method doesn't take only the
+    # output of the tokenizer as input
+    @require_torch
+    @slow
+    def test_torch_encode_plus_sent_to_model(self):
+        import torch
+
+        from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
+
+        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
+
+        tokenizers = self.get_tokenizers(do_lower_case=False)
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+                    return
+
+                config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+                config = config_class()
+
+                if config.is_encoder_decoder or config.pad_token_id is None:
+                    return
+
+                model = model_class(config)
+
+                # The following test is different from the common one
+                self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
+
+                # Build sequence
+                first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
+                sequence = " ".join(first_ten_tokens)
+                encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
+
+                # Ensure that the BatchEncoding.to() method works.
+                encoded_sequence.to(model.device)
+
+                batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
+
+                # This should not fail
+                with torch.no_grad():  # saves some time
+                    # The following lines are different from the common ones
+                    model.embed_questions(**encoded_sequence)
+                    model.embed_questions(**batch_encoded_sequence)
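For orientation, a minimal usage sketch of the call path this override exercises. RetriBertModel's forward pass expects paired query/document inputs rather than a single tokenizer output, which is why the test sends the encoded batches through `embed_questions` instead of calling the model directly. The checkpoint name below is taken from the RetriBert documentation and is an assumption here:

import torch
from transformers import RetriBertModel, RetriBertTokenizer

# Assumed checkpoint; any RetriBert checkpoint with a matching vocab works.
checkpoint = "yjernite/retribert-base-uncased"
tokenizer = RetriBertTokenizer.from_pretrained(checkpoint)
model = RetriBertModel.from_pretrained(checkpoint)

# encode_plus returns input_ids and attention_mask, exactly the arguments
# that embed_questions accepts, so the unpacked call below lines up.
encoded = tokenizer.encode_plus("how do transformers work?", return_tensors="pt").to(model.device)
with torch.no_grad():
    question_embedding = model.embed_questions(**encoded)
print(question_embedding.shape)  # (1, projection_dim); 128 with the default config

To run just the new test, note that the `@slow` decorator gates it behind the `RUN_SLOW=1` environment variable: `RUN_SLOW=1 python -m pytest tests/models/retribert/test_tokenization_retribert.py -k test_torch_encode_plus_sent_to_model`.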