Unverified Commit 9aeacb58 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Adding Fast tokenizers for SentencePiece based tokenizers - Breaking: remove...


Adding Fast tokenizers for SentencePiece based tokenizers - Breaking: remove Transfo-XL fast tokenizer (#7141)

* [WIP] SP tokenizers

* fixing tests for T5

* WIP tokenizers

* serialization

* update T5

* WIP T5 tokenization

* slow to fast conversion script

* Refactoring to move tokenzier implementations inside transformers

* Adding gpt - refactoring - quality

* WIP adding several tokenizers to the fast world

* WIP Roberta - moving implementations

* update to dev4 switch file loading to in-memory loading

* Updating and fixing

* advancing on the tokenizers - updating do_lower_case

* style and quality

* moving forward with tokenizers conversion and tests

* MBart, T5

* dumping the fast version of transformer XL

* Adding to autotokenizers + style/quality

* update init and space_between_special_tokens

* style and quality

* bump up tokenizers version

* add protobuf

* fix pickle Bert JP with Mecab

* fix newly added tokenizers

* style and quality

* fix bert japanese

* fix funnel

* limite tokenizer warning to one occurence

* clean up file

* fix new tokenizers

* fast tokenizers deep tests

* WIP adding all the special fast tests on the new fast tokenizers

* quick fix

* adding more fast tokenizers in the fast tests

* all tokenizers in fast version tested

* Adding BertGenerationFast

* bump up setup.py for CI

* remove BertGenerationFast (too early)

* bump up tokenizers version

* Clean old docstrings

* Typo

* Update following Lysandre comments
Co-authored-by: default avatarSylvain Gugger <sylvain.gugger@gmail.com>
parent 4d04120c
# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from transformers.testing_utils import _torch_available
from transformers.tokenization_camembert import CamembertTokenizer, CamembertTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
FRAMEWORK = "pt" if _torch_available else "tf"
class CamembertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = CamembertTokenizer
rust_tokenizer_class = CamembertTokenizerFast
test_rust_tokenizer = True
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = CamembertTokenizer(SAMPLE_VOCAB)
tokenizer.save_pretrained(self.tmpdirname)
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
sequence = "I was born in 92000, and this is falsé."
tokens = tokenizer.tokenize(sequence)
rust_tokens = rust_tokenizer.tokenize(sequence)
self.assertListEqual(tokens, rust_tokens)
ids = tokenizer.encode(sequence, add_special_tokens=False)
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
self.assertListEqual(ids, rust_ids)
rust_tokenizer = self.get_rust_tokenizer()
ids = tokenizer.encode(sequence)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
...@@ -56,7 +56,9 @@ def merge_model_tokenizer_mappings( ...@@ -56,7 +56,9 @@ def merge_model_tokenizer_mappings(
class TokenizerTesterMixin: class TokenizerTesterMixin:
tokenizer_class = None tokenizer_class = None
rust_tokenizer_class = None
test_rust_tokenizer = False test_rust_tokenizer = False
space_between_special_tokens = False
def setUp(self): def setUp(self):
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
...@@ -68,12 +70,15 @@ class TokenizerTesterMixin: ...@@ -68,12 +70,15 @@ class TokenizerTesterMixin:
input_txt = self.get_clean_sequence(tokenizer)[0] input_txt = self.get_clean_sequence(tokenizer)[0]
return input_txt, input_txt return input_txt, input_txt
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20) -> Tuple[str, list]: def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]:
toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))] toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))]
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks)) toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks)) toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
if max_length is not None and len(toks) > max_length: if max_length is not None and len(toks) > max_length:
toks = toks[:max_length] toks = toks[:max_length]
if min_length is not None and len(toks) < min_length and len(toks) > 0:
while len(toks) < min_length:
toks = toks + toks
# toks_str = [t[1] for t in toks] # toks_str = [t[1] for t in toks]
toks_ids = [t[0] for t in toks] toks_ids = [t[0] for t in toks]
...@@ -99,7 +104,7 @@ class TokenizerTesterMixin: ...@@ -99,7 +104,7 @@ class TokenizerTesterMixin:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast: def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
raise NotImplementedError return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
# def get_input_output_texts(self) -> Tuple[str, str]: # def get_input_output_texts(self) -> Tuple[str, str]:
# """Feel free to overwrite""" # """Feel free to overwrite"""
...@@ -118,6 +123,29 @@ class TokenizerTesterMixin: ...@@ -118,6 +123,29 @@ class TokenizerTesterMixin:
for i in range(len(batch_encode_plus_sequences["input_ids"])) for i in range(len(batch_encode_plus_sequences["input_ids"]))
] ]
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
sequence, _ = self.get_input_output_texts(tokenizer)
# We don't have an exact equivalence on `tokenize()` between Rust and Slow
# Slow tokenizer only split tokens, Rust tokenizers will replace with <unk>
# tokens = tokenizer.tokenize(sequence)
# rust_tokens = rust_tokenizer.tokenize(sequence)
# self.assertListEqual(tokens, rust_tokens)
ids = tokenizer.encode(sequence, add_special_tokens=False)
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
self.assertListEqual(ids, rust_ids)
ids = tokenizer.encode(sequence, add_special_tokens=True)
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=True)
self.assertListEqual(ids, rust_ids)
def test_tokenizers_common_properties(self): def test_tokenizers_common_properties(self):
tokenizers = self.get_tokenizers() tokenizers = self.get_tokenizers()
for tokenizer in tokenizers: for tokenizer in tokenizers:
...@@ -241,6 +269,9 @@ class TokenizerTesterMixin: ...@@ -241,6 +269,9 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers(fast=False, do_lower_case=True) tokenizers = self.get_tokenizers(fast=False, do_lower_case=True)
for tokenizer in tokenizers: for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"): with self.subTest(f"{tokenizer.__class__.__name__}"):
if not hasattr(tokenizer, "do_lower_case") or not tokenizer.do_lower_case:
continue
special_token = tokenizer.all_special_tokens[0] special_token = tokenizer.all_special_tokens[0]
text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
...@@ -272,6 +303,9 @@ class TokenizerTesterMixin: ...@@ -272,6 +303,9 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers(fast=False, do_lower_case=False) tokenizers = self.get_tokenizers(fast=False, do_lower_case=False)
for tokenizer in tokenizers: for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"): with self.subTest(f"{tokenizer.__class__.__name__}"):
if hasattr(tokenizer, "do_lower_case") and tokenizer.do_lower_case:
continue
special_token = tokenizer.all_special_tokens[0] special_token = tokenizer.all_special_tokens[0]
text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
...@@ -282,7 +316,7 @@ class TokenizerTesterMixin: ...@@ -282,7 +316,7 @@ class TokenizerTesterMixin:
toks0 = tokenizer.tokenize(text) # toks before adding new_toks toks0 = tokenizer.tokenize(text) # toks before adding new_toks
added = tokenizer.add_tokens(new_toks) added = tokenizer.add_tokens(new_toks)
self.assertEqual(added, 4) self.assertIn(added, [2, 4])
toks = tokenizer.tokenize(text) toks = tokenizer.tokenize(text)
toks2 = tokenizer.tokenize(text2) toks2 = tokenizer.tokenize(text2)
...@@ -390,12 +424,17 @@ class TokenizerTesterMixin: ...@@ -390,12 +424,17 @@ class TokenizerTesterMixin:
for tokenizer in tokenizers: for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"): with self.subTest(f"{tokenizer.__class__.__name__}"):
new_toks = ["[ABC]", "[DEF]"] # TODO(thom) add this one back when Rust toks are ready: , "GHI IHG"] # new_toks = ["[ABC]", "[DEF]"] # TODO(thom) add this one back when Rust toks are ready: , "GHI IHG"]
new_toks = [AddedToken("[ABC]", normalized=False), AddedToken("[DEF]", normalized=False)]
tokenizer.add_tokens(new_toks) tokenizer.add_tokens(new_toks)
input = "[ABC] [DEF] [ABC] [DEF]" # TODO(thom) add back cf above: "[ABC] [DEF] [ABC] GHI IHG [DEF]" input = "[ABC][DEF][ABC][DEF]" # TODO(thom) add back cf above: "[ABC] [DEF] [ABC] GHI IHG [DEF]"
if self.space_between_special_tokens:
output = "[ABC] [DEF] [ABC] [DEF]"
else:
output = input
encoded = tokenizer.encode(input, add_special_tokens=False) encoded = tokenizer.encode(input, add_special_tokens=False)
decoded = tokenizer.decode(encoded) decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens)
self.assertEqual(decoded, input) self.assertIn(decoded, [output, output.lower()])
def test_pretrained_model_lists(self): def test_pretrained_model_lists(self):
weights_list = list(self.tokenizer_class.max_model_input_sizes.keys()) weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
...@@ -447,7 +486,7 @@ class TokenizerTesterMixin: ...@@ -447,7 +486,7 @@ class TokenizerTesterMixin:
sequence = tokenizer.encode(seq_0, add_special_tokens=False) sequence = tokenizer.encode(seq_0, add_special_tokens=False)
total_length = len(sequence) total_length = len(sequence)
assert total_length > 1, "Issue with the testing sequence, please update it it's too short" assert total_length > 4, "Issue with the testing sequence, please update it it's too short"
# Test with max model input length # Test with max model input length
model_max_length = tokenizer.model_max_length model_max_length = tokenizer.model_max_length
...@@ -546,6 +585,7 @@ class TokenizerTesterMixin: ...@@ -546,6 +585,7 @@ class TokenizerTesterMixin:
model_max_length = tokenizer.model_max_length model_max_length = tokenizer.model_max_length
self.assertEqual(model_max_length, 100) self.assertEqual(model_max_length, 100)
seq_2 = seq_0 * model_max_length seq_2 = seq_0 * model_max_length
assert len(seq_2) > model_max_length
sequence1 = tokenizer(seq_1, add_special_tokens=False) sequence1 = tokenizer(seq_1, add_special_tokens=False)
total_length1 = len(sequence1["input_ids"]) total_length1 = len(sequence1["input_ids"])
...@@ -559,9 +599,9 @@ class TokenizerTesterMixin: ...@@ -559,9 +599,9 @@ class TokenizerTesterMixin:
[False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False] [False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
) )
for padding_state in padding_strategies: for padding_state in padding_strategies:
with self.subTest(f"Padding: {padding_state}"): with self.subTest(f"{tokenizer.__class__.__name__} Padding: {padding_state}"):
for truncation_state in [True, "longest_first", "only_first"]: for truncation_state in [True, "longest_first", "only_first"]:
with self.subTest(f"Truncation: {truncation_state}"): with self.subTest(f"{tokenizer.__class__.__name__} Truncation: {truncation_state}"):
output = tokenizer(seq_2, seq_1, padding=padding_state, truncation=truncation_state) output = tokenizer(seq_2, seq_1, padding=padding_state, truncation=truncation_state)
self.assertEqual(len(output["input_ids"]), model_max_length) self.assertEqual(len(output["input_ids"]), model_max_length)
...@@ -748,34 +788,47 @@ class TokenizerTesterMixin: ...@@ -748,34 +788,47 @@ class TokenizerTesterMixin:
# # This is not supported with the Rust tokenizers # # This is not supported with the Rust tokenizers
# # self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input) # # self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)
def test_swap_special_token(self): # def test_swap_special_token(self):
tokenizers = self.get_tokenizers(do_lower_case=False) # tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: # for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"): # with self.subTest(f"{tokenizer.__class__.__name__}"):
mask = "<mask>" # # Our mask token
sequence = "Encode this sequence" # mask = "<mask>"
sequence_masked_0 = "Encode <mask> sequence" # # We take a single word in the middle of the vocabulary
sequence_masked_1 = "<mask> this sequence" # all_tokens = sorted(tokenizer.get_vocab().keys())
# word = tokenizer.decode(tokenizer.encode(all_tokens[len(all_tokens)//2], add_special_tokens=False)[:1])
# Add tokens so that masked token isn't split
tokenizer.add_tokens(sequence.split()) # sequence_0 = "Encode " + word + " sequence"
tokenizer.add_special_tokens({"mask_token": mask}) # sequence_masked_0 = "Encode " + mask + " sequence"
mask_ind = tokenizer.convert_tokens_to_ids(mask)
encoded = tokenizer.encode(sequence, add_special_tokens=False) # sequence_1 = word + " this sequence"
# sequence_masked_1 = mask + " this sequence"
# Test first masked sequence
encoded_masked = tokenizer.encode(sequence_masked_0, add_special_tokens=False) # # Add tokens so that masked token isn't split
mask_loc = encoded_masked.index(mask_ind) # # tokens = [AddedToken(t, lstrip=True, normalized=False) for t in sequence.split()]
encoded_masked[mask_loc] = encoded[mask_loc] # # tokenizer.add_tokens(tokens)
# tokenizer.add_special_tokens(
self.assertEqual(encoded_masked, encoded) # {"mask_token": AddedToken(mask, normalized=False)}
# ) # Eat left space on Byte-level BPE tokenizers
# Test second masked sequence # mask_ind = tokenizer.convert_tokens_to_ids(mask)
encoded_masked = tokenizer.encode(sequence_masked_1, add_special_tokens=False)
mask_loc = encoded_masked.index(mask_ind) # # Test first masked sequence
encoded_masked[mask_loc] = encoded[mask_loc] # encoded_0 = tokenizer.encode(sequence_0, add_special_tokens=False)
# encoded_masked = tokenizer.encode(sequence_masked_0, add_special_tokens=False)
self.assertEqual(encoded_masked, encoded) # assert len(encoded_masked) == len(encoded_0)
# mask_loc = encoded_masked.index(mask_ind)
# encoded_masked[mask_loc] = encoded_0[mask_loc]
# self.assertEqual(encoded_masked, encoded_0)
# # Test second masked sequence
# encoded_1 = tokenizer.encode(sequence_1, add_special_tokens=False)
# encoded_masked = tokenizer.encode(sequence_masked_1, add_special_tokens=False)
# assert len(encoded_masked) == len(encoded_1)
# mask_loc = encoded_masked.index(mask_ind)
# encoded_masked[mask_loc] = encoded_1[mask_loc]
# self.assertEqual(encoded_masked, encoded_1)
def test_special_tokens_mask(self): def test_special_tokens_mask(self):
tokenizers = self.get_tokenizers(do_lower_case=False) tokenizers = self.get_tokenizers(do_lower_case=False)
...@@ -919,10 +972,10 @@ class TokenizerTesterMixin: ...@@ -919,10 +972,10 @@ class TokenizerTesterMixin:
def test_padding_to_multiple_of(self): def test_padding_to_multiple_of(self):
tokenizers = self.get_tokenizers() tokenizers = self.get_tokenizers()
for tokenizer in tokenizers: for tokenizer in tokenizers:
if tokenizer.pad_token is None: with self.subTest(f"{tokenizer.__class__.__name__}"):
self.skipTest("No padding token.") if tokenizer.pad_token is None:
else: self.skipTest("No padding token.")
with self.subTest(f"{tokenizer.__class__.__name__}"): else:
empty_tokens = tokenizer("", padding=True, pad_to_multiple_of=8) empty_tokens = tokenizer("", padding=True, pad_to_multiple_of=8)
normal_tokens = tokenizer("This is a sample input", padding=True, pad_to_multiple_of=8) normal_tokens = tokenizer("This is a sample input", padding=True, pad_to_multiple_of=8)
for key, value in empty_tokens.items(): for key, value in empty_tokens.items():
...@@ -1063,14 +1116,15 @@ class TokenizerTesterMixin: ...@@ -1063,14 +1116,15 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers(do_lower_case=False) tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"): with self.subTest(f"{tokenizer.__class__.__name__}"):
vocab = tokenizer.get_vocab() vocab_dict = tokenizer.get_vocab()
self.assertIsInstance(vocab_dict, dict)
self.assertGreaterEqual(len(tokenizer), len(vocab_dict))
self.assertIsInstance(vocab, dict) vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
self.assertEqual(len(vocab), len(tokenizer)) self.assertEqual(len(vocab), len(tokenizer))
tokenizer.add_tokens(["asdfasdfasdfasdf"]) tokenizer.add_tokens(["asdfasdfasdfasdf"])
vocab = tokenizer.get_vocab() vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
self.assertIsInstance(vocab, dict)
self.assertEqual(len(vocab), len(tokenizer)) self.assertEqual(len(vocab), len(tokenizer))
def test_conversion_reversible(self): def test_conversion_reversible(self):
...@@ -1079,6 +1133,8 @@ class TokenizerTesterMixin: ...@@ -1079,6 +1133,8 @@ class TokenizerTesterMixin:
with self.subTest(f"{tokenizer.__class__.__name__}"): with self.subTest(f"{tokenizer.__class__.__name__}"):
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
for word, ind in vocab.items(): for word, ind in vocab.items():
if word == tokenizer.unk_token:
continue
self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind) self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word) self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
...@@ -1173,12 +1229,13 @@ class TokenizerTesterMixin: ...@@ -1173,12 +1229,13 @@ class TokenizerTesterMixin:
def test_added_token_serializable(self): def test_added_token_serializable(self):
tokenizers = self.get_tokenizers(do_lower_case=False) tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: for tokenizer in tokenizers:
new_token = AddedToken("new_token", lstrip=True) with self.subTest(f"{tokenizer.__class__.__name__}"):
tokenizer.add_special_tokens({"additional_special_tokens": [new_token]}) new_token = AddedToken("new_token", lstrip=True)
tokenizer.add_special_tokens({"additional_special_tokens": [new_token]})
with tempfile.TemporaryDirectory() as tmp_dir_name: with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name) tokenizer.save_pretrained(tmp_dir_name)
tokenizer.from_pretrained(tmp_dir_name) tokenizer.from_pretrained(tmp_dir_name)
def test_batch_encode_plus_padding(self): def test_batch_encode_plus_padding(self):
# Test that padded sequences are equivalent between batch_encode_plus and encode_plus # Test that padded sequences are equivalent between batch_encode_plus and encode_plus
...@@ -1243,6 +1300,9 @@ class TokenizerTesterMixin: ...@@ -1243,6 +1300,9 @@ class TokenizerTesterMixin:
for tokenizer in tokenizers: for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"): with self.subTest(f"{tokenizer.__class__.__name__}"):
if hasattr(tokenizer, "add_prefix_space") and not tokenizer.add_prefix_space:
continue
# Prepare a sequence from our tokenizer vocabulary # Prepare a sequence from our tokenizer vocabulary
sequence, ids = self.get_clean_sequence(tokenizer, with_prefix_space=True, max_length=20) sequence, ids = self.get_clean_sequence(tokenizer, with_prefix_space=True, max_length=20)
# sequence = " " + sequence # To be sure the byte-level tokenizers are feeling good # sequence = " " + sequence # To be sure the byte-level tokenizers are feeling good
...@@ -1345,12 +1405,14 @@ class TokenizerTesterMixin: ...@@ -1345,12 +1405,14 @@ class TokenizerTesterMixin:
def test_prepare_for_model(self): def test_prepare_for_model(self):
tokenizers = self.get_tokenizers(do_lower_case=False) tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: for tokenizer in tokenizers:
string_sequence = "Testing the prepare_for_model method." with self.subTest(f"{tokenizer.__class__.__name__}"):
ids = tokenizer.encode(string_sequence, add_special_tokens=False) string_sequence = "Testing the prepare_for_model method."
input_dict = tokenizer.encode_plus(string_sequence) ids = tokenizer.encode(string_sequence, add_special_tokens=False)
prepared_input_dict = tokenizer.prepare_for_model(ids) prepared_input_dict = tokenizer.prepare_for_model(ids, add_special_tokens=True)
input_dict = tokenizer.encode_plus(string_sequence, add_special_tokens=True)
self.assertEqual(input_dict, prepared_input_dict) self.assertEqual(input_dict, prepared_input_dict)
def test_batch_encode_plus_overflowing_tokens(self): def test_batch_encode_plus_overflowing_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False) tokenizers = self.get_tokenizers(do_lower_case=False)
......
...@@ -25,6 +25,7 @@ from .test_tokenization_common import TokenizerTesterMixin ...@@ -25,6 +25,7 @@ from .test_tokenization_common import TokenizerTesterMixin
class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase): class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = CTRLTokenizer tokenizer_class = CTRLTokenizer
test_rust_tokenizer = False
def setUp(self): def setUp(self):
super().setUp() super().setUp()
......
...@@ -23,9 +23,8 @@ from .test_tokenization_bert import BertTokenizationTest ...@@ -23,9 +23,8 @@ from .test_tokenization_bert import BertTokenizationTest
class DistilBertTokenizationTest(BertTokenizationTest): class DistilBertTokenizationTest(BertTokenizationTest):
tokenizer_class = DistilBertTokenizer tokenizer_class = DistilBertTokenizer
rust_tokenizer_class = DistilBertTokenizerFast
def get_rust_tokenizer(self, **kwargs): test_rust_tokenizer = True
return DistilBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
@slow @slow
def test_sequence_builders(self): def test_sequence_builders(self):
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment