Unverified Commit 6435b9f9 authored by Lysandre Debut, committed by GitHub

Updating the TensorFlow models to work as expected with tokenizers v3.0.0 (#3684)

* Updating modeling tf files; adding tests

* Merge `encode_plus` and `batch_encode_plus`
parent 500aa123
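
In practical terms: with tokenizers v3.0.0 the tokenizers return a `BatchEncoding` object rather than a plain `dict`, and the TF main layers previously only accepted tuples, lists, or `dict`s. A minimal sketch of the behavior this commit enables (the model and checkpoint names are illustrative, not part of this commit):

from transformers import BertTokenizer, TFBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")

# encode_plus returns a BatchEncoding (a dict-like wrapper), not a plain dict.
# With this patch, TF models accept it directly as the first positional input.
inputs = tokenizer.encode_plus("Hello world", return_tensors="tf")
outputs = model(inputs)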
@@ -24,6 +24,7 @@ from .configuration_albert import AlbertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -526,7 +527,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             assert len(inputs) <= 6, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
......
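
The same one-line change is repeated in every TF main layer below. The root cause it addresses: `BatchEncoding` is dict-like but (at the time of this commit) derived from `collections.UserDict` rather than from `dict`, so the old `isinstance(inputs, dict)` branch was never taken and the encoding fell through to the catch-all branch that treats `inputs` as a bare `input_ids` tensor. A small sketch of the failure mode, assuming the `UserDict` base class:

from collections import UserDict

class DictLikeEncoding(UserDict):
    """Stand-in for BatchEncoding: dict-like, but not a dict subclass."""

enc = DictLikeEncoding({"input_ids": [[101, 102]]})
print(isinstance(enc, dict))              # False -> old branch skipped
print(isinstance(enc, (dict, UserDict)))  # True  -> widened check matches
print(enc.get("input_ids"))               # .get() works as the layers expect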
@@ -24,6 +24,7 @@ import tensorflow as tf
 from .configuration_bert import BertConfig
 from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -514,7 +515,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             assert len(inputs) <= 6, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
......
@@ -24,6 +24,7 @@ import tensorflow as tf
 from .configuration_ctrl import CTRLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -230,7 +231,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             head_mask = inputs[5] if len(inputs) > 5 else head_mask
             inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
             assert len(inputs) <= 7, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             past = inputs.get("past", past)
             attention_mask = inputs.get("attention_mask", attention_mask)
......
@@ -25,6 +25,7 @@ import tensorflow as tf
 from .configuration_distilbert import DistilBertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -421,7 +422,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
             head_mask = inputs[2] if len(inputs) > 2 else head_mask
             inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
             assert len(inputs) <= 4, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             head_mask = inputs.get("head_mask", head_mask)
......
@@ -7,6 +7,7 @@ from transformers import ElectraConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
 from .modeling_tf_utils import get_initializer, shape_list
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -237,7 +238,7 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             assert len(inputs) <= 6, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
......
@@ -30,6 +30,7 @@ from .modeling_tf_xlm import (
     get_masks,
     shape_list,
 )
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -141,7 +142,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
             head_mask = inputs[7] if len(inputs) > 7 else head_mask
             inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
             assert len(inputs) <= 9, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             langs = inputs.get("langs", langs)
......
@@ -32,6 +32,7 @@ from .modeling_tf_utils import (
     keras_serializable,
     shape_list,
 )
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -255,7 +256,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
             head_mask = inputs[5] if len(inputs) > 5 else head_mask
             inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
             assert len(inputs) <= 7, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             past = inputs.get("past", past)
             attention_mask = inputs.get("attention_mask", attention_mask)
......
@@ -31,6 +31,7 @@ from .modeling_tf_utils import (
     get_initializer,
     shape_list,
 )
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -248,7 +249,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             assert len(inputs) <= 6, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
......
@@ -25,6 +25,7 @@ from .configuration_transfo_xl import TransfoXLConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -519,7 +520,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
             head_mask = inputs[2] if len(inputs) > 2 else head_mask
             inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
             assert len(inputs) <= 4, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             mems = inputs.get("mems", mems)
             head_mask = inputs.get("head_mask", head_mask)
......
@@ -26,6 +26,7 @@ import tensorflow as tf
 from .configuration_xlm import XLMConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -324,7 +325,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             head_mask = inputs[7] if len(inputs) > 7 else head_mask
             inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
             assert len(inputs) <= 9, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             langs = inputs.get("langs", langs)
......
@@ -32,6 +32,7 @@ from .modeling_tf_utils import (
     keras_serializable,
     shape_list,
 )
+from .tokenization_utils import BatchEncoding

 logger = logging.getLogger(__name__)

@@ -515,7 +516,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             head_mask = inputs[7] if len(inputs) > 7 else head_mask
             inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
             assert len(inputs) <= 9, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
             mems = inputs.get("mems", mems)
......
@@ -80,6 +80,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = []

     def __init__(
         self,
@@ -413,6 +414,7 @@ class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES_FAST
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = []

     def __init__(
         self,
......
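
`model_input_names` tells the encoding methods which optional tensors (e.g. `token_type_ids`, `attention_mask`) the matching model consumes; the TF Transfo-XL layer takes neither, so the empty list keeps them out of the encoding. A hedged check of the expected effect (checkpoint name illustrative, and assuming the default return flags follow `model_input_names`):

from transformers import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
enc = tokenizer.encode_plus("hello world", return_tensors="tf")
# With model_input_names = [], only input_ids should remain in the encoding,
# so passing it to the TF Transfo-XL model no longer carries unexpected keys.
print(list(enc.keys()))  # expected: ['input_ids']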
@@ -18,10 +18,31 @@ import os
 import pickle
 import shutil
 import tempfile
+from collections import OrderedDict
+from typing import Dict, Tuple, Union

 from tests.utils import require_tf, require_torch


+def merge_model_tokenizer_mappings(
+    model_mapping: "Dict[PretrainedConfig, Union[PreTrainedModel, TFPreTrainedModel]]",  # noqa: F821
+    tokenizer_mapping: "Dict[PretrainedConfig, Tuple[PreTrainedTokenizer, PreTrainedTokenizerFast]]",  # noqa: F821
+) -> "Dict[Union[PreTrainedTokenizer, PreTrainedTokenizerFast], Tuple[PretrainedConfig, Union[PreTrainedModel, TFPreTrainedModel]]]":  # noqa: F821
+    configurations = list(model_mapping.keys())
+    model_tokenizer_mapping = OrderedDict([])
+
+    for configuration in configurations:
+        model = model_mapping[configuration]
+        tokenizer = tokenizer_mapping[configuration][0]
+        tokenizer_fast = tokenizer_mapping[configuration][1]
+
+        model_tokenizer_mapping.update({tokenizer: (configuration, model)})
+        if tokenizer_fast is not None:
+            model_tokenizer_mapping.update({tokenizer_fast: (configuration, model)})
+
+    return model_tokenizer_mapping


 class TokenizerTesterMixin:

     tokenizer_class = None
@@ -712,3 +733,83 @@ class TokenizerTesterMixin:

         # add pad_token_id to pass subsequent tests
         tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+
+    @require_torch
+    def test_torch_encode_plus_sent_to_model(self):
+        from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
+
+        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
+
+        tokenizer = self.get_tokenizer()
+        if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+            return
+
+        config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+        config = config_class()
+
+        if config.is_encoder_decoder or config.pad_token_id is None:
+            return
+
+        model = model_class(config)
+
+        # Make sure the model contains at least the full vocabulary size in its embedding matrix
+        is_using_common_embeddings = hasattr(model.get_input_embeddings(), "weight")
+        assert (model.get_input_embeddings().weight.shape[0] >= len(tokenizer)) if is_using_common_embeddings else True
+
+        # Build sequence
+        first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
+        sequence = " ".join(first_ten_tokens)
+        encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
+        batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
+
+        # This should not fail
+        model(**encoded_sequence)
+        model(**batch_encoded_sequence)
+
+        if self.test_rust_tokenizer:
+            fast_tokenizer = self.get_rust_tokenizer()
+            encoded_sequence_fast = fast_tokenizer.encode_plus(sequence, return_tensors="pt")
+            batch_encoded_sequence_fast = fast_tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
+
+            # This should not fail
+            model(**encoded_sequence_fast)
+            model(**batch_encoded_sequence_fast)
+
+    @require_tf
+    def test_tf_encode_plus_sent_to_model(self):
+        from transformers import TF_MODEL_MAPPING, TOKENIZER_MAPPING
+
+        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(TF_MODEL_MAPPING, TOKENIZER_MAPPING)
+
+        tokenizer = self.get_tokenizer()
+        if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
+            return
+
+        config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
+        config = config_class()
+
+        if config.is_encoder_decoder or config.pad_token_id is None:
+            return
+
+        model = model_class(config)
+
+        # Make sure the model contains at least the full vocabulary size in its embedding matrix
+        assert model.config.vocab_size >= len(tokenizer)
+
+        # Build sequence
+        first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
+        sequence = " ".join(first_ten_tokens)
+        encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="tf")
+        batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="tf")
+
+        # This should not fail
+        model(encoded_sequence)
+        model(batch_encoded_sequence)
+
+        if self.test_rust_tokenizer:
+            fast_tokenizer = self.get_rust_tokenizer()
+            encoded_sequence_fast = fast_tokenizer.encode_plus(sequence, return_tensors="tf")
+            batch_encoded_sequence_fast = fast_tokenizer.batch_encode_plus([sequence, sequence], return_tensors="tf")
+
+            # This should not fail
+            model(encoded_sequence_fast)
+            model(batch_encoded_sequence_fast)
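
Note the calling-convention difference the two tests exercise: the PyTorch test unpacks the encoding into keyword arguments (`model(**encoded_sequence)`), while the TF test passes the `BatchEncoding` itself as the first positional argument (`model(encoded_sequence)`), which is exactly the code path fixed by the `isinstance` changes above.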
@@ -14,7 +14,7 @@
 # limitations under the License.


-from transformers.tokenization_distilbert import DistilBertTokenizer
+from transformers.tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast

 from .test_tokenization_bert import BertTokenizationTest
 from .utils import slow

@@ -27,6 +27,9 @@ class DistilBertTokenizationTest(BertTokenizationTest):
     def get_tokenizer(self, **kwargs):
         return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

+    def get_rust_tokenizer(self, **kwargs):
+        return DistilBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
     @slow
     def test_sequence_builders(self):
         tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")