"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b9a768b3ffa80c4c19d024f9f42d5917e7d8109e"
Unverified Commit 6d211429 authored by SaulLu's avatar SaulLu Committed by GitHub
Browse files

fix retribert's `test_torch_encode_plus_sent_to_model` (#17231)

parent ec7f8af1
...@@ -27,9 +27,9 @@ from transformers.models.bert.tokenization_bert import ( ...@@ -27,9 +27,9 @@ from transformers.models.bert.tokenization_bert import (
_is_punctuation, _is_punctuation,
_is_whitespace, _is_whitespace,
) )
from transformers.testing_utils import require_tokenizers, slow from transformers.testing_utils import require_tokenizers, require_torch, slow
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
# Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert # Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
...@@ -338,3 +338,47 @@ class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -338,3 +338,47 @@ class RetriBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
] ]
self.assertListEqual(tokens_without_spe_char_p, expected_tokens) self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
self.assertListEqual(tokens_without_spe_char_r, expected_tokens) self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
# RetriBertModel doesn't define `get_input_embeddings` and it's forward method doesn't take only the output of the tokenizer as input
@require_torch
@slow
def test_torch_encode_plus_sent_to_model(self):
import torch
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
return
config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
config = config_class()
if config.is_encoder_decoder or config.pad_token_id is None:
return
model = model_class(config)
# The following test is different from the common's one
self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
# Build sequence
first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
sequence = " ".join(first_ten_tokens)
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
# Ensure that the BatchEncoding.to() method works.
encoded_sequence.to(model.device)
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
# This should not fail
with torch.no_grad(): # saves some time
# The following lines are different from the common's ones
model.embed_questions(**encoded_sequence)
model.embed_questions(**batch_encoded_sequence)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment