Unverified Commit 07dd7c2f authored by Sam Shleifer, committed by GitHub

[cleanup] test_tokenization_common.py (#4390)

parent 8f1d0471
@@ -198,11 +198,12 @@ Follow these steps to start contributing:
    are useful to avoid duplicated work, and to differentiate it from PRs ready
    to be merged;
 4. Make sure existing tests pass;
-5. Add high-coverage tests. No quality test, no merge.
+5. Add high-coverage tests. No quality testing = no merge.
    - If you are adding a new model, make sure that you use `ModelTester.all_model_classes = (MyModel, MyModelWithLMHead,...)`, which triggers the common tests.
    - If you are adding new `@slow` tests, make sure they pass using `RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
+   - If you are adding a new tokenizer, write tests, and make sure `RUN_SLOW=1 python -m pytest tests/test_tokenization_{your_model_name}.py` passes.
    CircleCI does not run them.
-6. All public methods must have informative docstrings;
+6. All public methods must have informative docstrings that work nicely with sphinx. See `modeling_ctrl.py` for an example.

 ### Tests
...
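The `RUN_SLOW=1` checklist items above rely on the project's `@slow` test decorator, which skips expensive tests unless the environment variable is set. The real decorator lives in the repository's test utilities; the following is only a minimal sketch of the gating mechanism it implements (the class and test names are invented for illustration):

```python
import os
import unittest


def slow(test_case):
    """Skip the decorated test unless RUN_SLOW is set to a truthy value."""
    run_slow = os.environ.get("RUN_SLOW", "0").lower() in ("1", "true", "yes")
    if not run_slow:
        return unittest.skip("test is slow; set RUN_SLOW=1 to run it")(test_case)
    return test_case


class ExampleSlowTest(unittest.TestCase):
    @slow
    def test_loads_a_real_checkpoint(self):
        # a real slow test would download and exercise a full pretrained tokenizer here
        self.assertTrue(True)
```

With a gate like this in place, `python -m pytest tests/...` skips the slow tests, while `RUN_SLOW=1 python -m pytest tests/...` runs them, which is exactly what the contributing checklist asks contributors to verify.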
@@ -199,7 +199,7 @@ class RobertaTokenizer(GPT2Tokenizer):
         if token_ids_1 is not None:
             raise ValueError(
                 "You should not supply a second sequence if the provided sequence of "
-                "ids is already formated with special tokens for the model."
+                "ids is already formatted with special tokens for the model."
             )
         return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
...
@@ -771,26 +771,26 @@ class PreTrainedTokenizer(SpecialTokensMixin):
         raise NotImplementedError

     @property
-    def is_fast(self):
+    def is_fast(self) -> bool:
         return False

     @property
-    def max_len(self):
+    def max_len(self) -> int:
         """ Kept here for backward compatibility.
             Now renamed to `model_max_length` to avoid ambiguity.
         """
         return self.model_max_length

     @property
-    def max_len_single_sentence(self):
+    def max_len_single_sentence(self) -> int:
         return self.model_max_length - self.num_special_tokens_to_add(pair=False)

     @property
-    def max_len_sentences_pair(self):
+    def max_len_sentences_pair(self) -> int:
         return self.model_max_length - self.num_special_tokens_to_add(pair=True)

     @max_len_single_sentence.setter
-    def max_len_single_sentence(self, value):
+    def max_len_single_sentence(self, value) -> int:
         """ For backward compatibility, allow to try to setup 'max_len_single_sentence' """
         if value == self.model_max_length - self.num_special_tokens_to_add(pair=False):
             logger.warning(
@@ -802,7 +802,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
             )

     @max_len_sentences_pair.setter
-    def max_len_sentences_pair(self, value):
+    def max_len_sentences_pair(self, value) -> int:
         """ For backward compatibility, allow to try to setup 'max_len_sentences_pair' """
         if value == self.model_max_length - self.num_special_tokens_to_add(pair=True):
             logger.warning(
@@ -1118,7 +1118,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
         return vocab_files + (special_tokens_map_file, added_tokens_file)

-    def save_vocabulary(self, save_directory):
+    def save_vocabulary(self, save_directory) -> Tuple[str]:
         """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
             and special token mappings.
@@ -1128,7 +1128,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
         """
         raise NotImplementedError

-    def add_tokens(self, new_tokens):
+    def add_tokens(self, new_tokens: Union[str, List[str]]) -> int:
         """
         Add a list of new tokens to the tokenizer class. If the new tokens are not in the
         vocabulary, they are added to it with indices starting from length of the current vocabulary.
@@ -1156,7 +1156,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
         if not isinstance(new_tokens, list):
             new_tokens = [new_tokens]

-        to_add_tokens = []
+        tokens_to_add = []
         for token in new_tokens:
             assert isinstance(token, str)
             if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
@@ -1164,18 +1164,18 @@ class PreTrainedTokenizer(SpecialTokensMixin):
             if (
                 token != self.unk_token
                 and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
-                and token not in to_add_tokens
+                and token not in tokens_to_add
             ):
-                to_add_tokens.append(token)
+                tokens_to_add.append(token)
                 logger.info("Adding %s to the vocabulary", token)

-        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
+        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
         added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
         self.added_tokens_encoder.update(added_tok_encoder)
         self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
         self.added_tokens_decoder.update(added_tok_decoder)

-        return len(to_add_tokens)
+        return len(tokens_to_add)

     def num_special_tokens_to_add(self, pair=False):
         """
@@ -2080,10 +2080,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
     def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
         """
         Build model inputs from a sequence or a pair of sequence for sequence classification tasks
-        by concatenating and adding special tokens.
-        A RoBERTa sequence has the following format:
-            single sequence: <s> X </s>
-            pair of sequences: <s> A </s></s> B </s>
+        by concatenating and adding special tokens. This implementation does not add special tokens.
         """
         if token_ids_1 is None:
             return token_ids_0
...
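The `add_tokens` docstring above describes how new tokens extend the vocabulary starting at `len(tokenizer)`. A brief usage sketch (the checkpoint name and token strings are only examples, and the exact number added depends on what the vocabulary already contains):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# tokens not already in the vocab get ids starting at len(tokenizer)
num_added = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
print("number of tokens added:", num_added)

# any model paired with this tokenizer then needs its embeddings resized, e.g.:
# model.resize_token_embeddings(len(tokenizer))
```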
@@ -36,9 +36,6 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
         tokenizer.save_pretrained(self.tmpdirname)

-    def get_tokenizer(self, **kwargs):
-        return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_input_output_texts(self):
         input_text = "this is a test"
         output_text = "this is a test"
...
@@ -59,9 +59,6 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-    def get_tokenizer(self, **kwargs):
-        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_rust_tokenizer(self, **kwargs):
         return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
...
@@ -60,9 +60,6 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-    def get_tokenizer(self, **kwargs):
-        return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_input_output_texts(self):
         input_text = "こんにちは、世界。 \nこんばんは、世界。"
         output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
...
@@ -22,12 +22,12 @@ from collections import OrderedDict
 from typing import TYPE_CHECKING, Dict, Tuple, Union

 from tests.utils import require_tf, require_torch
+from transformers import PreTrainedTokenizer

 if TYPE_CHECKING:
     from transformers import (
         PretrainedConfig,
-        PreTrainedTokenizer,
         PreTrainedTokenizerFast,
         PreTrainedModel,
         TFPreTrainedModel,
@@ -67,19 +67,24 @@ class TokenizerTesterMixin:
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)

-    def get_tokenizer(self, **kwargs):
-        raise NotImplementedError
+    def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
+        return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

     def get_rust_tokenizer(self, **kwargs):
         raise NotImplementedError

-    def get_input_output_texts(self):
-        raise NotImplementedError
+    def get_input_output_texts(self) -> Tuple[str, str]:
+        """Feel free to overwrite"""
+        # TODO: @property
+        return (
+            "This is a test",
+            "This is a test",
+        )

     @staticmethod
     def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
         # Switch from batch_encode_plus format: {'input_ids': [[...], [...]], ...}
-        # to the concatenated encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
+        # to the list of examples/ encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
         return [
             {value: batch_encode_plus_sequences[value][i] for value in batch_encode_plus_sequences.keys()}
             for i in range(len(batch_encode_plus_sequences["input_ids"]))
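To make the helper above concrete, here is a small standalone illustration of the shape change it performs (the id values are made up; only the dict-of-lists to list-of-dicts transposition matters):

```python
batch = {
    "input_ids": [[101, 7592, 102], [101, 2088, 102]],
    "attention_mask": [[1, 1, 1], [1, 1, 1]],
}

# same transposition as convert_batch_encode_plus_format_to_encode_plus
examples = [
    {key: batch[key][i] for key in batch.keys()}
    for i in range(len(batch["input_ids"]))
]

assert examples[0] == {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1]}
assert examples[1] == {"input_ids": [101, 2088, 102], "attention_mask": [1, 1, 1]}
```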
@@ -114,13 +119,13 @@ class TokenizerTesterMixin:
         # Now let's start the test
         tokenizer = self.get_tokenizer(max_len=42)

-        before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
+        sample_text = "He is very happy, UNwant\u00E9d,running"
+        before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)

         tokenizer.save_pretrained(self.tmpdirname)
         tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)

-        after_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
+        after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
         self.assertListEqual(before_tokens, after_tokens)

         self.assertEqual(tokenizer.max_len, 42)
@@ -128,6 +133,7 @@ class TokenizerTesterMixin:
         self.assertEqual(tokenizer.max_len, 43)

     def test_pickle_tokenizer(self):
+        """Google pickle __getstate__ __setstate__ if you are struggling with this."""
         tokenizer = self.get_tokenizer()
         self.assertIsNotNone(tokenizer)
@@ -253,7 +259,7 @@ class TokenizerTesterMixin:
             decoded = tokenizer.decode(encoded, skip_special_tokens=True)
             assert special_token not in decoded

-    def test_required_methods_tokenizer(self):
+    def test_internal_consistency(self):
         tokenizer = self.get_tokenizer()
         input_text, output_text = self.get_input_output_texts()
@@ -263,13 +269,12 @@ class TokenizerTesterMixin:
         self.assertListEqual(ids, ids_2)

         tokens_2 = tokenizer.convert_ids_to_tokens(ids)
+        self.assertNotEqual(len(tokens_2), 0)
         text_2 = tokenizer.decode(ids)
+        self.assertIsInstance(text_2, str)

         self.assertEqual(text_2, output_text)

-        self.assertNotEqual(len(tokens_2), 0)
-        self.assertIsInstance(text_2, str)
-
     def test_encode_decode_with_spaces(self):
         tokenizer = self.get_tokenizer()
@@ -429,10 +434,7 @@ class TokenizerTesterMixin:
     def test_special_tokens_mask(self):
         tokenizer = self.get_tokenizer()
         sequence_0 = "Encode this."
-        sequence_1 = "This one too please."

         # Testing single inputs
         encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
         encoded_sequence_dict = tokenizer.encode_plus(
@@ -442,13 +444,13 @@ class TokenizerTesterMixin:
         special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
         self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

-        filtered_sequence = [
-            (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
-        ]
-        filtered_sequence = [x for x in filtered_sequence if x is not None]
+        filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]]
         self.assertEqual(encoded_sequence, filtered_sequence)

-        # Testing inputs pairs
+    def test_special_tokens_mask_input_pairs(self):
+        tokenizer = self.get_tokenizer()
+        sequence_0 = "Encode this."
+        sequence_1 = "This one too please."
         encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
         encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
         encoded_sequence_dict = tokenizer.encode_plus(
@@ -464,7 +466,9 @@ class TokenizerTesterMixin:
         filtered_sequence = [x for x in filtered_sequence if x is not None]
         self.assertEqual(encoded_sequence, filtered_sequence)

-        # Testing with already existing special tokens
+    def test_special_tokens_mask_already_has_special_tokens(self):
+        tokenizer = self.get_tokenizer()
+        sequence_0 = "Encode this."
         if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
             tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
         encoded_sequence_dict = tokenizer.encode_plus(
@@ -514,13 +518,12 @@ class TokenizerTesterMixin:
         tokenizer.padding_side = "right"
         padded_sequence_right = tokenizer.encode(sequence, pad_to_max_length=True)
         padded_sequence_right_length = len(padded_sequence_right)
+        assert sequence_length == padded_sequence_right_length
+        assert encoded_sequence == padded_sequence_right

         tokenizer.padding_side = "left"
         padded_sequence_left = tokenizer.encode(sequence, pad_to_max_length=True)
         padded_sequence_left_length = len(padded_sequence_left)
-
-        assert sequence_length == padded_sequence_right_length
-        assert encoded_sequence == padded_sequence_right
         assert sequence_length == padded_sequence_left_length
         assert encoded_sequence == padded_sequence_left
@@ -617,6 +620,9 @@ class TokenizerTesterMixin:
         self.assertIsInstance(vocab, dict)
         self.assertEqual(len(vocab), len(tokenizer))

+    def test_conversion_reversible(self):
+        tokenizer = self.get_tokenizer()
+        vocab = tokenizer.get_vocab()
         for word, ind in vocab.items():
             self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
             self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
@@ -746,6 +752,7 @@ class TokenizerTesterMixin:
     @require_torch
     def test_torch_encode_plus_sent_to_model(self):
+        import torch
         from transformers import MODEL_MAPPING, TOKENIZER_MAPPING

         MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
@@ -773,6 +780,8 @@ class TokenizerTesterMixin:
                 encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
                 batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
                 # This should not fail
+                with torch.no_grad():  # saves some time
                     model(**encoded_sequence)
                     model(**batch_encoded_sequence)
...
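The net effect of the mixin changes above is that `get_tokenizer` now has a working default built from `self.tokenizer_class` and `self.tmpdirname`, and `get_input_output_texts` has a sensible fallback; that is why the per-model overrides are deleted throughout the test files in this commit. A minimal sketch of what a subclass needs after the cleanup, using `BertTokenizer` purely as an illustration (the class name and fixture vocab here are invented, not part of the diff):

```python
# Hypothetical minimal subclass -- illustrative only.
import os
import unittest

from transformers import BertTokenizer

from .test_tokenization_common import TokenizerTesterMixin


class MyBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertTokenizer  # the mixin's default get_tokenizer() instantiates this class

    def setUp(self):
        super().setUp()
        # write a tiny vocab fixture so from_pretrained(self.tmpdirname) can load it
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "hello", "world"]
        with open(os.path.join(self.tmpdirname, "vocab.txt"), "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join(token + "\n" for token in vocab_tokens))
```

Everything else (pickling, padding, special-token masks, and so on) comes for free from the common tests in `TokenizerTesterMixin`.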
@@ -24,9 +24,6 @@ class DistilBertTokenizationTest(BertTokenizationTest):

     tokenizer_class = DistilBertTokenizer

-    def get_tokenizer(self, **kwargs):
-        return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_rust_tokenizer(self, **kwargs):
         return DistilBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
...
@@ -64,13 +64,8 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self, **kwargs):
-        return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_input_output_texts(self):
-        input_text = "lower newer"
-        output_text = "lower newer"
-        return input_text, output_text
+        return "lower newer", "lower newer"

     def test_full_tokenizer(self):
         tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
...
@@ -37,14 +37,6 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = T5Tokenizer(SAMPLE_VOCAB)
         tokenizer.save_pretrained(self.tmpdirname)

-    def get_tokenizer(self, **kwargs):
-        return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
-    def get_input_output_texts(self):
-        input_text = "This is a test"
-        output_text = "This is a test"
-        return input_text, output_text
-
     def test_full_tokenizer(self):
         tokenizer = T5Tokenizer(SAMPLE_VOCAB)
...
@@ -65,9 +65,6 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self, **kwargs):
-        return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_input_output_texts(self):
         input_text = "lower newer"
         output_text = "lower newer"
...
@@ -17,6 +17,7 @@
 import os
 import unittest

+from transformers.file_utils import cached_property
 from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer

 from .test_tokenization_common import TokenizerTesterMixin
@@ -37,14 +38,6 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
         tokenizer.save_pretrained(self.tmpdirname)

-    def get_tokenizer(self, **kwargs):
-        return XLMRobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
-    def get_input_output_texts(self):
-        input_text = "This is a test"
-        output_text = "This is a test"
-        return input_text, output_text
-
     def test_full_tokenizer(self):
         tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
@@ -121,22 +114,22 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             ],
         )

+    @cached_property
+    def big_tokenizer(self):
+        return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
+
     @slow
     def test_tokenization_base_easy_symbols(self):
-        tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
         symbols = "Hello World!"
         original_tokenizer_encodings = [0, 35378, 6661, 38, 2]
         # xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.base')  # xlmr.large has same tokenizer
         # xlmr.eval()
         # xlmr.encode(symbols)

-        self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
+        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))

     @slow
     def test_tokenization_base_hard_symbols(self):
-        tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
         symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
         original_tokenizer_encodings = [
             0,
@@ -209,4 +202,4 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         # xlmr.eval()
         # xlmr.encode(symbols)

-        self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
+        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
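The new `big_tokenizer` property above loads the full `xlm-roberta-base` sentencepiece model once and reuses it across both slow tests instead of calling `from_pretrained` twice. A small, self-contained illustration of the caching behaviour it relies on, using the standard-library `functools.cached_property` (assumption: `transformers.file_utils.cached_property` serves the same purpose, with the helper mainly existing for older Python versions):

```python
from functools import cached_property  # Python 3.8+


class ExpensiveLoader:
    @cached_property
    def big_object(self):
        print("loading once...")  # the expensive work runs only on first access
        return list(range(1_000_000))


loader = ExpensiveLoader()
first = loader.big_object   # prints "loading once..." and caches the result on the instance
second = loader.big_object  # served from the cache; no reload
assert first is second
```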
@@ -37,14 +37,6 @@ class XLNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
         tokenizer.save_pretrained(self.tmpdirname)

-    def get_tokenizer(self, **kwargs):
-        return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
-    def get_input_output_texts(self):
-        input_text = "This is a test"
-        output_text = "This is a test"
-        return input_text, output_text
-
     def test_full_tokenizer(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
...