Unverified commit d2f21f08, authored by Thomas Wolf, committed by GitHub

Merge pull request #1092 from shijie-wu/xlm-tokenization

Added cleaned configuration properties for tokenizer with serialization - improve tokenization of XLM
Parents: 12b9cc9e 7044ed6b
@@ -44,6 +44,8 @@ XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin",
 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin",
+'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.json",
+'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.json",
 }
 XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
@@ -54,6 +56,8 @@ XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
+'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
+'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
 }
@@ -114,6 +118,7 @@ class XLMConfig(PretrainedConfig):
 causal=False,
 asm=False,
 n_langs=1,
+use_lang_emb=True,
 max_position_embeddings=512,
 embed_init_std=2048 ** -0.5,
 layer_norm_eps=1e-12,
@@ -157,6 +162,7 @@ class XLMConfig(PretrainedConfig):
 self.causal = causal
 self.asm = asm
 self.n_langs = n_langs
+self.use_lang_emb = use_lang_emb
 self.layer_norm_eps = layer_norm_eps
 self.bos_index = bos_index
 self.eos_index = eos_index
@@ -488,7 +494,7 @@ class XLMModel(XLMPreTrainedModel):
 """
 ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output',
-'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads',
+'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads',
 'hidden_dim', 'dropout', 'attention_dropout', 'asm',
 'asm_cutoffs', 'asm_div_value']
@@ -507,6 +513,7 @@ class XLMModel(XLMPreTrainedModel):
 # dictionary / languages
 self.n_langs = config.n_langs
+self.use_lang_emb = config.use_lang_emb
 self.n_words = config.n_words
 self.eos_index = config.eos_index
 self.pad_index = config.pad_index
@@ -529,7 +536,7 @@ class XLMModel(XLMPreTrainedModel):
 self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
 if config.sinusoidal_embeddings:
 create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
-if config.n_langs > 1:
+if config.n_langs > 1 and config.use_lang_emb:
 self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
 self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
 self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
@@ -628,7 +635,7 @@ class XLMModel(XLMPreTrainedModel):
 # embeddings
 tensor = self.embeddings(input_ids)
 tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
-if langs is not None:
+if langs is not None and self.use_lang_emb:
 tensor = tensor + self.lang_embeddings(langs)
 if token_type_ids is not None:
 tensor = tensor + self.embeddings(token_type_ids)
...
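Note: the hunks above thread a new `use_lang_emb` flag from `XLMConfig` into `XLMModel`, so checkpoints whose language embeddings were not trained (notably the new 17- and 100-language MLM models) can skip the `lang_embeddings` lookup even when `langs` is passed. A minimal sketch of the flag in use, assuming the `pytorch_transformers` API of this commit; the tiny dimensions below are illustrative only, not taken from any released checkpoint:

```python
import torch
from pytorch_transformers import XLMConfig, XLMModel

# Tiny illustrative config: many languages in the vocabulary, but no language embeddings.
config = XLMConfig(vocab_size_or_config_json_file=100, emb_dim=64, n_layers=2, n_heads=2,
                   n_langs=17, use_lang_emb=False)
model = XLMModel(config)

input_ids = torch.tensor([[0, 5, 7, 1]])
langs = torch.zeros_like(input_ids)          # ignored in the forward pass when use_lang_emb=False
last_hidden_state = model(input_ids, langs=langs)[0]
print(last_hidden_state.shape)               # torch.Size([1, 4, 64])
```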
@@ -41,8 +41,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
 with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
 vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-def get_tokenizer(self):
-return self.tokenizer_class.from_pretrained(self.tmpdirname)
+def get_tokenizer(self, **kwargs):
+return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 def get_input_output_texts(self):
 input_text = u"UNwant\u00E9d,running"
...
@@ -27,8 +27,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):
 tokenizer_class = DistilBertTokenizer
-def get_tokenizer(self):
-return DistilBertTokenizer.from_pretrained(self.tmpdirname)
+def get_tokenizer(self, **kwargs):
+return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 def test_sequence_builders(self):
 tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
...
@@ -44,8 +44,9 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
 with open(self.merges_file, "w") as fp:
 fp.write("\n".join(merges))
-def get_tokenizer(self):
-return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
+def get_tokenizer(self, **kwargs):
+kwargs.update(self.special_tokens_map)
+return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
 def get_input_output_texts(self):
 input_text = u"lower newer"
...
@@ -45,8 +45,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
 with open(self.merges_file, "w") as fp:
 fp.write("\n".join(merges))
-def get_tokenizer(self):
-return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)
+def get_tokenizer(self, **kwargs):
+return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 def get_input_output_texts(self):
 input_text = u"lower newer"
...
@@ -43,8 +43,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
 with open(self.merges_file, "w") as fp:
 fp.write("\n".join(merges))
-def get_tokenizer(self):
-return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
+def get_tokenizer(self, **kwargs):
+kwargs.update(self.special_tokens_map)
+return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 def get_input_output_texts(self):
 input_text = u"lower newer"
...
@@ -49,24 +49,33 @@ class CommonTestCases:
 def tearDown(self):
 shutil.rmtree(self.tmpdirname)
-def get_tokenizer(self):
+def get_tokenizer(self, **kwargs):
 raise NotImplementedError
 def get_input_output_texts(self):
 raise NotImplementedError
 def test_save_and_load_tokenizer(self):
+# safety check on max_len default value so we are sure the test works
 tokenizer = self.get_tokenizer()
+self.assertNotEqual(tokenizer.max_len, 42)
+# Now let's start the test
+tokenizer = self.get_tokenizer(max_len=42)
 before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
 with TemporaryDirectory() as tmpdirname:
 tokenizer.save_pretrained(tmpdirname)
-tokenizer = tokenizer.from_pretrained(tmpdirname)
+tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
 after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
 self.assertListEqual(before_tokens, after_tokens)
+self.assertEqual(tokenizer.max_len, 42)
+tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
+self.assertEqual(tokenizer.max_len, 43)
 def test_pickle_tokenizer(self):
 tokenizer = self.get_tokenizer()
 self.assertIsNotNone(tokenizer)
...
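The new assertions above capture what this PR adds to the common tokenizer tests: an `__init__` keyword such as `max_len` must now survive a `save_pretrained()` / `from_pretrained()` round trip. A hedged sketch of the same round trip outside the test harness, assuming the `bert-base-uncased` vocabulary is cached locally or downloadable:

```python
from tempfile import TemporaryDirectory
from pytorch_transformers import BertTokenizer

# Load with a non-default init kwarg; it is recorded in tokenizer.init_kwargs.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', max_len=42)

with TemporaryDirectory() as tmpdir:
    tokenizer.save_pretrained(tmpdir)        # now also writes tokenizer_config.json
    reloaded = BertTokenizer.from_pretrained(tmpdir)

assert reloaded.max_len == 42                # restored from tokenizer_config.json
```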
@@ -37,8 +37,9 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
 with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
 vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-def get_tokenizer(self):
-return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
+def get_tokenizer(self, **kwargs):
+kwargs['lower_case'] = True
+return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 def get_input_output_texts(self):
 input_text = u"<unk> UNwanted , running"
...
@@ -44,8 +44,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
 with open(self.merges_file, "w") as fp:
 fp.write("\n".join(merges))
-def get_tokenizer(self):
-return XLMTokenizer.from_pretrained(self.tmpdirname)
+def get_tokenizer(self, **kwargs):
+return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 def get_input_output_texts(self):
 input_text = u"lower newer"
...
@@ -35,8 +35,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
 tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
 tokenizer.save_pretrained(self.tmpdirname)
-def get_tokenizer(self):
-return XLNetTokenizer.from_pretrained(self.tmpdirname)
+def get_tokenizer(self, **kwargs):
+return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 def get_input_output_texts(self):
 input_text = u"This is a test"
...
@@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
 'bert-base-cased-finetuned-mrpc': 512,
 }
+PRETRAINED_INIT_CONFIGURATION = {
+'bert-base-uncased': {'do_lower_case': True},
+'bert-large-uncased': {'do_lower_case': True},
+'bert-base-cased': {'do_lower_case': False},
+'bert-large-cased': {'do_lower_case': False},
+'bert-base-multilingual-uncased': {'do_lower_case': True},
+'bert-base-multilingual-cased': {'do_lower_case': False},
+'bert-base-chinese': {'do_lower_case': False},
+'bert-base-german-cased': {'do_lower_case': False},
+'bert-large-uncased-whole-word-masking': {'do_lower_case': True},
+'bert-large-cased-whole-word-masking': {'do_lower_case': False},
+'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
+'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
+'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
+}
 def load_vocab(vocab_file):
 """Loads a vocabulary file into a dictionary."""
 vocab = collections.OrderedDict()
@@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer):
 vocab_files_names = VOCAB_FILES_NAMES
 pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
 max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
@@ -202,24 +220,6 @@ class BertTokenizer(PreTrainedTokenizer):
 index += 1
 return (vocab_file,)
-@classmethod
-def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
-""" Instantiate a BertTokenizer from pre-trained vocabulary files.
-"""
-if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
-if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
-logger.warning("The pre-trained model you are loading is a cased model but you have not set "
-"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
-"you may want to check this behavior.")
-kwargs['do_lower_case'] = False
-elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
-logger.warning("The pre-trained model you are loading is an uncased model but you have set "
-"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
-"but you may want to check this behavior.")
-kwargs['do_lower_case'] = True
-return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
 class BasicTokenizer(object):
 """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
...
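Net effect of the tokenization_bert changes: the hard-coded `from_pretrained` override, which guessed the casing from a `-cased` substring and emitted warnings, is replaced by a declarative `pretrained_init_configuration` map that the base class consults. A minimal sketch of the resulting behavior, assuming the S3 vocabulary files are reachable:

```python
from pytorch_transformers import BertTokenizer

# The per-checkpoint init configuration supplies do_lower_case automatically.
cased = BertTokenizer.from_pretrained('bert-base-cased')
uncased = BertTokenizer.from_pretrained('bert-base-uncased')
print(cased.basic_tokenizer.do_lower_case)    # False, from PRETRAINED_INIT_CONFIGURATION
print(uncased.basic_tokenizer.do_lower_case)  # True

# Explicit kwargs still win, because user kwargs are merged on top of the map
# (init_kwargs.update(kwargs) in _from_pretrained).
forced = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=True)
print(forced.basic_tokenizer.do_lower_case)   # True
```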
@@ -95,6 +95,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
 # in a library like ours, at all.
 vocab_dict = torch.load(pretrained_vocab_file)
 for key, value in vocab_dict.items():
+if key not in self.__dict__:
 self.__dict__[key] = value
 if vocab_file is not None:
...
@@ -20,6 +20,7 @@ import logging
 import os
 import json
 import six
+import copy
 from io import open
 from .file_utils import cached_path
@@ -28,6 +29,7 @@ logger = logging.getLogger(__name__)
 SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
 ADDED_TOKENS_FILE = 'added_tokens.json'
+TOKENIZER_CONFIG_FILE = 'tokenizer_config.json'
 class PreTrainedTokenizer(object):
 """ Base class for all tokenizers.
@@ -40,6 +42,7 @@ class PreTrainedTokenizer(object):
 - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
 - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
 - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
+- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionnary of specific arguments to pass to the ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
 Parameters:
@@ -61,6 +64,7 @@ class PreTrainedTokenizer(object):
 """
 vocab_files_names = {}
 pretrained_vocab_files_map = {}
+pretrained_init_configuration = {}
 max_model_input_sizes = {}
 SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
@@ -166,12 +170,15 @@ class PreTrainedTokenizer(object):
 self._additional_special_tokens = []
 self.max_len = max_len if max_len is not None else int(1e12)
+self.max_len_single_sentence = self.max_len
+self.max_len_sentences_pair = self.max_len
+# Added tokens
 self.added_tokens_encoder = {}
 self.added_tokens_decoder = {}
+# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+self.init_inputs = ()
+self.init_kwargs = {}
 for key, value in kwargs.items():
 if key in self.SPECIAL_TOKENS_ATTRIBUTES:
 if key == 'additional_special_tokens':
@@ -231,17 +238,20 @@ class PreTrainedTokenizer(object):
 @classmethod
-def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
 cache_dir = kwargs.pop('cache_dir', None)
 force_download = kwargs.pop('force_download', False)
 proxies = kwargs.pop('proxies', None)
 s3_models = list(cls.max_model_input_sizes.keys())
 vocab_files = {}
+init_configuration = {}
 if pretrained_model_name_or_path in s3_models:
 # Get the vocabulary from AWS S3 bucket
 for file_id, map_list in cls.pretrained_vocab_files_map.items():
 vocab_files[file_id] = map_list[pretrained_model_name_or_path]
+if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
+init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
 else:
 # Get the vocabulary from local files
 logger.info(
@@ -264,15 +274,17 @@ class PreTrainedTokenizer(object):
 vocab_files[file_id] = full_file_name
 # Look for the additional tokens files
-all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
-'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
+additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
+'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE,
+'tokenizer_config_file': TOKENIZER_CONFIG_FILE,
+}
 # If a path to a file was provided, get the parent directory
 saved_directory = pretrained_model_name_or_path
 if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
 saved_directory = os.path.dirname(saved_directory)
-for file_id, file_name in all_vocab_files_names.items():
+for file_id, file_name in additional_files_names.items():
 full_file_name = os.path.join(saved_directory, file_name)
 if not os.path.exists(full_file_name):
 logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
@@ -315,28 +327,46 @@ class PreTrainedTokenizer(object):
 logger.info("loading file {} from cache at {}".format(
 file_path, resolved_vocab_files[file_id]))
+# Prepare tokenizer initialization kwargs
+# Did we saved some inputs and kwargs to reload ?
+tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
+if tokenizer_config_file is not None:
+init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
+saved_init_inputs = init_kwargs.pop('init_inputs', ())
+if not init_inputs:
+init_inputs = saved_init_inputs
+else:
+init_kwargs = init_configuration
+# Update with newly provided kwargs
+init_kwargs.update(kwargs)
 # Set max length if needed
 if pretrained_model_name_or_path in cls.max_model_input_sizes:
 # if we're using a pretrained model, ensure the tokenizer
 # wont index sequences longer than the number of positional embeddings
 max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
 if max_len is not None and isinstance(max_len, (int, float)):
-kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len)
-# Merge resolved_vocab_files arguments in kwargs.
+# Merge resolved_vocab_files arguments in init_kwargs.
 added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
 special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
 for args_name, file_path in resolved_vocab_files.items():
-if args_name not in kwargs:
-kwargs[args_name] = file_path
+if args_name not in init_kwargs:
+init_kwargs[args_name] = file_path
 if special_tokens_map_file is not None:
 special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
 for key, value in special_tokens_map.items():
-if key not in kwargs:
-kwargs[key] = value
+if key not in init_kwargs:
+init_kwargs[key] = value
 # Instantiate tokenizer.
-tokenizer = cls(*inputs, **kwargs)
+tokenizer = cls(*init_inputs, **init_kwargs)
+# Save inputs and kwargs for saving and re-loading with ``save_pretrained``
+tokenizer.init_inputs = init_inputs
+tokenizer.init_kwargs = init_kwargs
 # Add supplementary tokens.
 if added_tokens_file is not None:
@@ -349,8 +379,13 @@ class PreTrainedTokenizer(object):
 def save_pretrained(self, save_directory):
-""" Save the tokenizer vocabulary files (with added tokens) and the
-special-tokens-to-class-attributes-mapping to a directory.
+""" Save the tokenizer vocabulary files together with:
+- added tokens,
+- special-tokens-to-class-attributes-mapping,
+- tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert).
+This won't save modifications other than (added tokens and special token mapping) you may have
+applied to the tokenizer after the instantion (e.g. modifying tokenizer.do_lower_case after creation).
 This method make sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
 """
@@ -360,6 +395,15 @@ class PreTrainedTokenizer(object):
 special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
 added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
+tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
+tokenizer_config = copy.deepcopy(self.init_kwargs)
+tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
+for file_id in self.vocab_files_names.keys():
+tokenizer_config.pop(file_id, None)
+with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
+f.write(json.dumps(tokenizer_config, ensure_ascii=False))
 with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
 f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
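With the block above, `save_pretrained()` now writes a third file, `tokenizer_config.json`, holding the positional and keyword arguments the tokenizer was instantiated with (vocabulary file paths are stripped out). A sketch of what lands on disk, assuming a BERT tokenizer loaded from the `bert-base-german-cased` shortcut; the exact keys depend on the tokenizer class and the kwargs used at load time:

```python
import json, os
from tempfile import TemporaryDirectory
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-german-cased')

with TemporaryDirectory() as tmpdir:
    tokenizer.save_pretrained(tmpdir)
    print(sorted(os.listdir(tmpdir)))
    # expected: ['added_tokens.json', 'special_tokens_map.json', 'tokenizer_config.json', 'vocab.txt']
    with open(os.path.join(tmpdir, 'tokenizer_config.json'), encoding='utf-8') as f:
        print(json.load(f))
    # something like {'do_lower_case': False, 'max_len': 512, 'init_inputs': []};
    # 'vocab_file' is popped because it is listed in vocab_files_names.
```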
@@ -566,7 +610,7 @@ class PreTrainedTokenizer(object):
 def _convert_token_to_id(self, token):
 raise NotImplementedError
-def encode(self, text, text_pair=None, add_special_tokens=False):
+def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
 """
 Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@@ -577,15 +621,16 @@ class PreTrainedTokenizer(object):
 text_pair: Optional second sequence to be encoded.
 add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
 to their model.
+**kwargs: passed to the `self.tokenize()` method
 """
 if text_pair is None:
 if add_special_tokens:
-return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
+return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
 else:
-return self.convert_tokens_to_ids(self.tokenize(text))
+return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
-first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
-second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)]
+first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
+second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
 if add_special_tokens:
 return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
...
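`encode()` now forwards extra keyword arguments to `tokenize()`, which is what the reworked XLM tokenizer (whose diff is collapsed below) needs for language-aware preprocessing via sacremoses. A hedged sketch; the `lang` keyword is an assumption about the updated `XLMTokenizer.tokenize()` signature, and the checkpoint files must be available:

```python
from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enfr-1024')

# Extra kwargs used to be silently dropped by encode(); they now reach tokenize().
en_ids = tokenizer.encode(u"Hello, how are you?", lang='en')
fr_ids = tokenizer.encode(u"Bonjour, comment allez-vous ?", lang='fr')
```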
This diff is collapsed.
@@ -61,7 +61,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
 pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
 max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-def __init__(self, vocab_file, max_len=None,
+def __init__(self, vocab_file,
 do_lower_case=False, remove_space=True, keep_accents=False,
 bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
 pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
...
@@ -10,3 +10,5 @@ requests
 regex
 # For XLNet
 sentencepiece
+# For XLM
+sacremoses
\ No newline at end of file
@@ -55,7 +55,8 @@ setup(
 'requests',
 'tqdm',
 'regex',
-'sentencepiece'],
+'sentencepiece',
+'sacremoses'],
 entry_points={
 'console_scripts': [
 "pytorch_transformers=pytorch_transformers.__main__:main",
...