Commit 1e82cd84 authored by Lysandre

Flaubert auto tokenizer + tests

cc @julien-c
parent d18d47be
@@ -50,8 +50,8 @@ class FlaubertConfig(XLMConfig):
             Probability to drop layers during training (Fan et al., Reducing Transformer Depth on Demand
             with Structured Dropout. ICLR 2020)
         vocab_size (:obj:`int`, optional, defaults to 30145):
-            Vocabulary size of the XLM model. Defines the different tokens that
-            can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.XLMModel`.
+            Vocabulary size of the Flaubert model. Defines the different tokens that
+            can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.FlaubertModel`.
         emb_dim (:obj:`int`, optional, defaults to 2048):
             Dimensionality of the encoder layers and the pooler layer.
         n_layer (:obj:`int`, optional, defaults to 12):
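For reference, a minimal sketch of what these documented defaults look like on the config object. This is illustrative only and not part of the diff; FlaubertConfig subclasses XLMConfig, so the XLM-style arguments are accepted:

from transformers import FlaubertConfig

# Instantiate with the defaults documented above; any value can be overridden.
config = FlaubertConfig(vocab_size=30145, emb_dim=2048)
print(config.vocab_size, config.emb_dim)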
......
@@ -25,6 +25,7 @@ from .configuration_auto import (
     CamembertConfig,
     CTRLConfig,
     DistilBertConfig,
+    FlaubertConfig,
     GPT2Config,
     OpenAIGPTConfig,
     RobertaConfig,
@@ -41,6 +42,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer
 from .tokenization_camembert import CamembertTokenizer
 from .tokenization_ctrl import CTRLTokenizer
 from .tokenization_distilbert import DistilBertTokenizer
+from .tokenization_flaubert import FlaubertTokenizer
 from .tokenization_gpt2 import GPT2Tokenizer
 from .tokenization_openai import OpenAIGPTTokenizer
 from .tokenization_roberta import RobertaTokenizer
@@ -67,6 +69,7 @@ TOKENIZER_MAPPING = OrderedDict(
         (GPT2Config, GPT2Tokenizer),
         (TransfoXLConfig, TransfoXLTokenizer),
         (XLNetConfig, XLNetTokenizer),
+        (FlaubertConfig, FlaubertTokenizer),
         (XLMConfig, XLMTokenizer),
         (CTRLConfig, CTRLTokenizer),
     ]
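Note on ordering: FlaubertConfig subclasses XLMConfig, and the auto classes pick a tokenizer by walking this ordered mapping with isinstance-style checks, so the Flaubert entry has to sit before the XLM one. A rough, self-contained sketch of that lookup, using stand-in classes rather than the real library code:

from collections import OrderedDict

# Stand-in classes that mirror the real inheritance relationship.
class XLMConfig: ...
class FlaubertConfig(XLMConfig): ...
class XLMTokenizer: ...
class FlaubertTokenizer: ...

# Child config listed before its parent, as in TOKENIZER_MAPPING above.
MAPPING = OrderedDict([(FlaubertConfig, FlaubertTokenizer), (XLMConfig, XLMTokenizer)])

def resolve(config):
    # The first isinstance match wins; if XLMConfig came first, a Flaubert
    # config would be routed to the XLM tokenizer instead.
    for config_class, tokenizer_class in MAPPING.items():
        if isinstance(config, config_class):
            return tokenizer_class
    raise ValueError("Unrecognized configuration class")

assert resolve(FlaubertConfig()) is FlaubertTokenizer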
......
@@ -39,6 +39,14 @@ if is_torch_available():
         BertForQuestionAnswering,
     )
     from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
+    from transformers.modeling_auto import (
+        MODEL_MAPPING,
+        MODEL_FOR_PRETRAINING_MAPPING,
+        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        MODEL_WITH_LM_HEAD_MAPPING,
+    )


 @require_torch
@@ -127,3 +135,26 @@ class AutoModelTest(unittest.TestCase):
             self.assertIsInstance(model, RobertaForMaskedLM)
             self.assertEqual(model.num_parameters(), 14830)
             self.assertEqual(model.num_parameters(only_trainable=True), 14830)
+
+    def test_parents_and_children_in_mappings(self):
+        # Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
+        # by the parents and will return the wrong configuration type when using auto models
+
+        mappings = (
+            MODEL_MAPPING,
+            MODEL_FOR_PRETRAINING_MAPPING,
+            MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+            MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+            MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+            MODEL_WITH_LM_HEAD_MAPPING,
+        )
+
+        for mapping in mappings:
+            mapping = tuple(mapping.items())
+            for index, (child_config, child_model) in enumerate(mapping[1:]):
+                for parent_config, parent_model in mapping[: index + 1]:
+                    with self.subTest(
+                        msg="Testing if {} is child of {}".format(child_config.__name__, parent_config.__name__)
+                    ):
+                        self.assertFalse(issubclass(child_config, parent_config))
+                        self.assertFalse(issubclass(child_model, parent_model))
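To see what this test guards against, here is the same parent/child check run on a deliberately mis-ordered toy mapping; the classes below are stand-ins, not the real transformers ones:

from collections import OrderedDict

class ParentConfig: ...
class ChildConfig(ParentConfig): ...
class ParentModel: ...
class ChildModel(ParentModel): ...

# Parent listed first: an isinstance-based lookup would never reach the child entry.
bad_mapping = tuple(OrderedDict([(ParentConfig, ParentModel), (ChildConfig, ChildModel)]).items())

for index, (child_config, child_model) in enumerate(bad_mapping[1:]):
    for parent_config, parent_model in bad_mapping[: index + 1]:
        if issubclass(child_config, parent_config):
            # This branch fires for the toy mapping, which is exactly what the test flags.
            print("{} is listed after its parent {}".format(child_config.__name__, parent_config.__name__))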
@@ -25,6 +25,7 @@ from transformers import (
     GPT2Tokenizer,
     RobertaTokenizer,
 )
+from transformers.tokenization_auto import TOKENIZER_MAPPING

 from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, slow  # noqa: F401
@@ -70,3 +71,19 @@ class AutoTokenizerTest(unittest.TestCase):
         for tokenizer_class in [BertTokenizer, AutoTokenizer]:
             with self.assertRaises(EnvironmentError):
                 _ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
+
+    def test_parents_and_children_in_mappings(self):
+        # Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
+        # by the parents and will return the wrong configuration type when using auto models
+
+        mappings = (TOKENIZER_MAPPING,)
+
+        for mapping in mappings:
+            mapping = tuple(mapping.items())
+            for index, (child_config, child_model) in enumerate(mapping[1:]):
+                for parent_config, parent_model in mapping[: index + 1]:
+                    with self.subTest(
+                        msg="Testing if {} is child of {}".format(child_config.__name__, parent_config.__name__)
+                    ):
+                        self.assertFalse(issubclass(child_config, parent_config))
+                        self.assertFalse(issubclass(child_model, parent_model))
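With the mapping entry and tests in place, a Flaubert checkpoint should resolve to FlaubertTokenizer through the auto class. A hedged usage sketch; the checkpoint identifier below is an assumed example, not something this commit introduces:

from transformers import AutoTokenizer, FlaubertTokenizer

# "flaubert-base-cased" is a placeholder identifier; substitute a Flaubert
# checkpoint that actually exists on the model hub.
tokenizer = AutoTokenizer.from_pretrained("flaubert-base-cased")
assert isinstance(tokenizer, FlaubertTokenizer)
print(tokenizer.tokenize("Bonjour, le monde !"))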