Unverified Commit 4fc9f9ef authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #910 from huggingface/auto_models

Adding AutoTokenizer and AutoModel classes that automatically detect architecture - Clean up tokenizers
parents 3a126e73 d43dc48b
...@@ -20,12 +20,16 @@ import json ...@@ -20,12 +20,16 @@ import json
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory from .tokenization_tests_commons import CommonTestCases
class XLMTokenizationTest(unittest.TestCase): class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
def test_full_tokenizer(self): tokenizer_class = XLMTokenizer
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
def setUp(self):
super(XLMTokenizationTest, self).setUp()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"w</w>", "r</w>", "t</w>", "w</w>", "r</w>", "t</w>",
"lo", "low", "er</w>", "lo", "low", "er</w>",
...@@ -33,30 +37,34 @@ class XLMTokenizationTest(unittest.TestCase): ...@@ -33,30 +37,34 @@ class XLMTokenizationTest(unittest.TestCase):
vocab_tokens = dict(zip(vocab, range(len(vocab)))) vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""] merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
with TemporaryDirectory() as tmpdirname: self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file']) with open(self.vocab_file, "w") as fp:
with open(vocab_file, "w") as fp: fp.write(json.dumps(vocab_tokens))
fp.write(json.dumps(vocab_tokens)) with open(self.merges_file, "w") as fp:
with open(merges_file, "w") as fp: fp.write("\n".join(merges))
fp.write("\n".join(merges))
input_text = u"lower newer" def get_tokenizer(self):
output_text = u"lower newer" return XLMTokenizer.from_pretrained(self.tmpdirname)
create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname) def get_input_output_texts(self):
input_text = u"lower newer"
output_text = u"lower newer"
return input_text, output_text
tokenizer = XLMTokenizer(vocab_file, merges_file) def test_full_tokenizer(self):
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)
text = "lower" text = "lower"
bpe_tokens = ["low", "er</w>"] bpe_tokens = ["low", "er</w>"]
tokens = tokenizer.tokenize(text) tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens) self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + ["<unk>"] input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [14, 15, 20] input_bpe_tokens = [14, 15, 20]
self.assertListEqual( self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -19,48 +19,58 @@ import unittest ...@@ -19,48 +19,58 @@ import unittest
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory from .tokenization_tests_commons import CommonTestCases
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'fixtures/test_sentencepiece.model') 'fixtures/test_sentencepiece.model')
class XLNetTokenizationTest(unittest.TestCase): class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = XLNetTokenizer
def setUp(self):
super(XLNetTokenizationTest, self).setUp()
# We have a SentencePiece fixture for testing
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self):
return XLNetTokenizer.from_pretrained(self.tmpdirname)
def get_input_output_texts(self):
input_text = u"This is a test"
output_text = u"This is a test"
return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
with TemporaryDirectory() as tmpdirname: tokens = tokenizer.tokenize(u'This is a test')
tokenizer.save_pretrained(tmpdirname) self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
input_text = u"This is a test" self.assertListEqual(
output_text = u"This is a test" tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname) tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
tokens = tokenizer.tokenize(u'This is a test') u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
self.assertListEqual( ids = tokenizer.convert_tokens_to_ids(tokens)
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) self.assertListEqual(
ids, [8, 21, 84, 55, 24, 19, 7, 0,
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 602, 347, 347, 347, 3, 12, 66,
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 46, 72, 80, 6, 0, 4])
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', back_tokens = tokenizer.convert_ids_to_tokens(ids)
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
ids = tokenizer.convert_tokens_to_ids(tokens) u'or', u'n', SPIECE_UNDERLINE + u'in',
self.assertListEqual( SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
ids, [8, 21, 84, 55, 24, 19, 7, 0, SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
602, 347, 347, 347, 3, 12, 66, SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
46, 72, 80, 6, 0, 4]) u'<unk>', u'.'])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in',
SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
u'<unk>', u'.'])
def test_tokenizer_lower(self): def test_tokenizer_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
......
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
from .tokenization_bert import BertTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_xlnet import XLNetTokenizer
from .tokenization_xlm import XLMTokenizer
logger = logging.getLogger(__name__)
class AutoTokenizer(object):
r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library
when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
class method.
The `from_pretrained()` method take care of returning the correct tokenizer class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
This class cannot be instantiated using `__init__()` (throw an error).
"""
def __init__(self):
raise EnvironmentError("AutoTokenizer is designed to be instantiated "
"using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r""" Instantiate a one of the tokenizer classes of the library
from a pre-trained model vocabulary.
The tokenizer class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
Params:
**pretrained_model_name_or_path**: either:
- a string with the `shortcut name` of a pre-trained model configuration to load from cache
or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
- a path to a `directory` containing a configuration file saved
using the `save_pretrained(save_directory)` method.
- a path or url to a saved configuration `file`.
**cache_dir**: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
Examples::
config = AutoTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from S3 and cache.
config = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`
"""
if 'bert' in pretrained_model_name_or_path:
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'openai-gpt' in pretrained_model_name_or_path:
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'gpt2' in pretrained_model_name_or_path:
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'transfo-xl' in pretrained_model_name_or_path:
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm'".format(pretrained_model_name_or_path))
...@@ -22,7 +22,7 @@ import os ...@@ -22,7 +22,7 @@ import os
import unicodedata import unicodedata
from io import open from io import open
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -86,7 +86,7 @@ def whitespace_tokenize(text): ...@@ -86,7 +86,7 @@ def whitespace_tokenize(text):
class BertTokenizer(PreTrainedTokenizer): class BertTokenizer(PreTrainedTokenizer):
r""" r"""
Constructs a BertTokenizer. Constructs a BertTokenizer.
:class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece :class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
Args: Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file vocab_file: Path to a one-wordpiece-per-line vocabulary file
...@@ -119,7 +119,7 @@ class BertTokenizer(PreTrainedTokenizer): ...@@ -119,7 +119,7 @@ class BertTokenizer(PreTrainedTokenizer):
Only has an effect when do_basic_tokenize=True Only has an effect when do_basic_tokenize=True
**tokenize_chinese_chars**: (`optional`) boolean (default True) **tokenize_chinese_chars**: (`optional`) boolean (default True)
Whether to tokenize Chinese characters. Whether to tokenize Chinese characters.
This should likely be desactivated for Japanese: This should likely be deactivated for Japanese:
see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
""" """
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
...@@ -214,7 +214,7 @@ class BasicTokenizer(object): ...@@ -214,7 +214,7 @@ class BasicTokenizer(object):
List of token not to split. List of token not to split.
**tokenize_chinese_chars**: (`optional`) boolean (default True) **tokenize_chinese_chars**: (`optional`) boolean (default True)
Whether to tokenize Chinese characters. Whether to tokenize Chinese characters.
This should likely be desactivated for Japanese: This should likely be deactivated for Japanese:
see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
""" """
if never_split is None: if never_split is None:
......
...@@ -31,7 +31,7 @@ except ImportError: ...@@ -31,7 +31,7 @@ except ImportError:
def lru_cache(): def lru_cache():
return lambda func: func return lambda func: func
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -102,9 +102,9 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -102,9 +102,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace', def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>",
bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs): bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs) super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
self.encoder = json.load(open(vocab_file)) self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v:k for k,v in self.encoder.items()}
...@@ -177,9 +177,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -177,9 +177,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """ """ Converts a token (str/unicode) in an id using the vocab. """
if token in self.encoder: return self.encoder.get(token, self.encoder.get(self.unk_token))
return self.encoder.get(token)
return self.encoder.get(self.unk_token)
def _convert_id_to_token(self, index): def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (string/unicode) using the vocab.""" """Converts an index (integer) in a token (string/unicode) using the vocab."""
......
...@@ -30,7 +30,7 @@ import torch ...@@ -30,7 +30,7 @@ import torch
import numpy as np import numpy as np
from .file_utils import cached_path from .file_utils import cached_path
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization from .tokenization_utils import PreTrainedTokenizer
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
import cPickle as pickle import cPickle as pickle
......
...@@ -30,14 +30,34 @@ SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json' ...@@ -30,14 +30,34 @@ SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json' ADDED_TOKENS_FILE = 'added_tokens.json'
class PreTrainedTokenizer(object): class PreTrainedTokenizer(object):
""" An abstract class to handle dowloading and loading pretrained tokenizers and adding tokens to the vocabulary. """ Base class for all tokenizers.
Handle all the shared methods for tokenization and special tokens as well as methods dowloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary.
Derived class can set up a few special tokens to be used in common scripts and internals: This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token
additional_special_tokens = []
We defined an added_tokens_encoder to add new tokens to the vocabulary without having to handle the Class attributes (overridden by derived classes):
specific vocabulary augmentation methods of the various underlying dictionnary structures (BPE, sentencepiece...).
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
Parameters:
- ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token``
- ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token``
- ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token``
- ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token``
- ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token``
- ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token``
- ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token``
- ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensure they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens``
""" """
vocab_files_names = {} vocab_files_names = {}
pretrained_vocab_files_map = {} pretrained_vocab_files_map = {}
...@@ -49,48 +69,56 @@ class PreTrainedTokenizer(object): ...@@ -49,48 +69,56 @@ class PreTrainedTokenizer(object):
@property @property
def bos_token(self): def bos_token(self):
""" Beginning of sentence token (string). Log an error if used while not having been set. """
if self._bos_token is None: if self._bos_token is None:
logger.error("Using bos_token, but it is not set yet.") logger.error("Using bos_token, but it is not set yet.")
return self._bos_token return self._bos_token
@property @property
def eos_token(self): def eos_token(self):
""" End of sentence token (string). Log an error if used while not having been set. """
if self._eos_token is None: if self._eos_token is None:
logger.error("Using eos_token, but it is not set yet.") logger.error("Using eos_token, but it is not set yet.")
return self._eos_token return self._eos_token
@property @property
def unk_token(self): def unk_token(self):
""" Unknown token (string). Log an error if used while not having been set. """
if self._unk_token is None: if self._unk_token is None:
logger.error("Using unk_token, but it is not set yet.") logger.error("Using unk_token, but it is not set yet.")
return self._unk_token return self._unk_token
@property @property
def sep_token(self): def sep_token(self):
""" Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
if self._sep_token is None: if self._sep_token is None:
logger.error("Using sep_token, but it is not set yet.") logger.error("Using sep_token, but it is not set yet.")
return self._sep_token return self._sep_token
@property @property
def pad_token(self): def pad_token(self):
""" Padding token (string). Log an error if used while not having been set. """
if self._pad_token is None: if self._pad_token is None:
logger.error("Using pad_token, but it is not set yet.") logger.error("Using pad_token, but it is not set yet.")
return self._pad_token return self._pad_token
@property @property
def cls_token(self): def cls_token(self):
""" Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
if self._cls_token is None: if self._cls_token is None:
logger.error("Using cls_token, but it is not set yet.") logger.error("Using cls_token, but it is not set yet.")
return self._cls_token return self._cls_token
@property @property
def mask_token(self): def mask_token(self):
""" Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
if self._mask_token is None: if self._mask_token is None:
logger.error("Using mask_token, but it is not set yet.") logger.error("Using mask_token, but it is not set yet.")
return self._mask_token return self._mask_token
@property @property
def additional_special_tokens(self): def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. """
if self._additional_special_tokens is None: if self._additional_special_tokens is None:
logger.error("Using additional_special_tokens, but it is not set yet.") logger.error("Using additional_special_tokens, but it is not set yet.")
return self._additional_special_tokens return self._additional_special_tokens
...@@ -143,20 +171,58 @@ class PreTrainedTokenizer(object): ...@@ -143,20 +171,58 @@ class PreTrainedTokenizer(object):
for key, value in kwargs.items(): for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == 'additional_special_tokens':
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
else:
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
setattr(self, key, value) setattr(self, key, value)
@classmethod @classmethod
def from_pretrained(cls, *inputs, **kwargs): def from_pretrained(cls, *inputs, **kwargs):
r""" Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.
Parameters:
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
- (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
Examples::
# We can't instantiate directly the base class `PreTrainedTokenizer` so let's show our examples on a derived class: BertTokenizer
# Download vocabulary from S3 and cache.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')
# If the tokenizer uses a single vocabulary file, you can point directly to this file
tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')
# You can link tokens to special vocabulary when instantiating
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
# You should be sure '<unk>' is in the vocabulary when doing that.
# Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)
assert tokenizer.unk_token == '<unk>'
"""
return cls._from_pretrained(*inputs, **kwargs) return cls._from_pretrained(*inputs, **kwargs)
@classmethod @classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
""" cache_dir = kwargs.pop('cache_dir', None)
Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
Download and cache the vocabulary files if needed.
"""
s3_models = list(cls.max_model_input_sizes.keys()) s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {} vocab_files = {}
if pretrained_model_name_or_path in s3_models: if pretrained_model_name_or_path in s3_models:
...@@ -271,8 +337,9 @@ class PreTrainedTokenizer(object): ...@@ -271,8 +337,9 @@ class PreTrainedTokenizer(object):
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" Save the tokenizer vocabulary files (with added tokens) and the """ Save the tokenizer vocabulary files (with added tokens) and the
special-tokens-to-class-attributes-mapping to a directory, so that it special-tokens-to-class-attributes-mapping to a directory.
can be re-loaded using the `from_pretrained(save_directory)` class method.
This method make sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
""" """
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Saving directory ({}) should be a directory".format(save_directory)) logger.error("Saving directory ({}) should be a directory".format(save_directory))
...@@ -297,38 +364,52 @@ class PreTrainedTokenizer(object): ...@@ -297,38 +364,52 @@ class PreTrainedTokenizer(object):
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
""" Save the tokenizer vocabulary to a directory. This method doesn't save added tokens """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
and special token mappings. and special token mappings.
Please use `save_pretrained()` to save the full Tokenizer state so that it can be Please use :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
reloaded using the `from_pretrained(save_directory)` class method.
""" """
raise NotImplementedError raise NotImplementedError
def vocab_size(self): def vocab_size(self):
""" Size of the base vocabulary (without the added tokens) """
raise NotImplementedError raise NotImplementedError
def __len__(self): def __len__(self):
""" Size of the full vocabulary with the added tokens """
return self.vocab_size + len(self.added_tokens_encoder) return self.vocab_size + len(self.added_tokens_encoder)
def add_tokens(self, new_tokens): def add_tokens(self, new_tokens):
""" Add a list of new tokens to the tokenizer class. If the new tokens are not in the """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to the added_tokens_encoder with indices starting from vocabulary, they are added to it with indices starting from length of the current vocabulary.
the last index of the current vocabulary.
Parameters:
new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
Returns: Returns:
Number of tokens added to the vocabulary which can be used to correspondingly Number of tokens added to the vocabulary.
increase the size of the associated model embedding matrices.
Examples::
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
""" """
if not new_tokens: if not new_tokens:
return 0 return 0
to_add_tokens = [] to_add_tokens = []
for token in new_tokens: for token in new_tokens:
if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token): assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
if token != self.unk_token and \
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
to_add_tokens.append(token) to_add_tokens.append(token)
logger.info("Adding %s to the vocabulary", token) logger.info("Adding %s to the vocabulary", token)
...@@ -341,24 +422,48 @@ class PreTrainedTokenizer(object): ...@@ -341,24 +422,48 @@ class PreTrainedTokenizer(object):
def add_special_tokens(self, special_tokens_dict): def add_special_tokens(self, special_tokens_dict):
""" Add a dictionnary of special tokens (eos, pad, cls...) to the encoder and link them """ Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
to class attributes. If the special tokens are not in the vocabulary, they are added to class attributes. If special tokens are NOT in the vocabulary, they are added
to it and indexed starting from the last index of the current vocabulary. to it (indexed starting from the last index of the current vocabulary).
Parameters:
special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``].
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
Returns: Returns:
Number of tokens added to the vocabulary which can be used to correspondingly Number of tokens added to the vocabulary.
increase the size of the associated model embedding matrices.
Examples::
# Let's see how to add a new classification token to GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
special_tokens_dict = {'cls_token': '<CLS>'}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
assert tokenizer.cls_token == '<CLS>'
""" """
if not special_tokens_dict: if not special_tokens_dict:
return 0 return 0
added_special_tokens = self.add_tokens(special_tokens_dict.values()) added_tokens = 0
for key, value in special_tokens_dict.items(): for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
if key == 'additional_special_tokens':
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
added_tokens += self.add_tokens(value)
else:
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
added_tokens += self.add_tokens([value])
logger.info("Assigning %s to the %s key of the tokenizer", value, key) logger.info("Assigning %s to the %s key of the tokenizer", value, key)
setattr(self, key, value) setattr(self, key, value)
return added_special_tokens return added_tokens
def tokenize(self, text, **kwargs): def tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer. """ Converts a string in a sequence of tokens (string), using the tokenizer.
...@@ -386,13 +491,13 @@ class PreTrainedTokenizer(object): ...@@ -386,13 +491,13 @@ class PreTrainedTokenizer(object):
Split in words for word-based vocabulary or sub-words for sub-word-based Split in words for word-based vocabulary or sub-words for sub-word-based
vocabularies (BPE/SentencePieces/WordPieces). vocabularies (BPE/SentencePieces/WordPieces).
Don't take care of added tokens. Do NOT take care of added tokens.
""" """
raise NotImplementedError raise NotImplementedError
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
""" Converts a single token or a sequence of tokens (str/unicode) in a integer id """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
(resp.) a sequence of ids, using the vocabulary. (resp. a sequence of ids), using the vocabulary.
""" """
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)): if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
return self._convert_token_to_id_with_added_voc(tokens) return self._convert_token_to_id_with_added_voc(tokens)
...@@ -417,7 +522,8 @@ class PreTrainedTokenizer(object): ...@@ -417,7 +522,8 @@ class PreTrainedTokenizer(object):
def encode(self, text): def encode(self, text):
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
same as self.convert_tokens_to_ids(self.tokenize(text)).
Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
""" """
return self.convert_tokens_to_ids(self.tokenize(text)) return self.convert_tokens_to_ids(self.tokenize(text))
...@@ -457,11 +563,13 @@ class PreTrainedTokenizer(object): ...@@ -457,11 +563,13 @@ class PreTrainedTokenizer(object):
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces. with options to remove special tokens and clean up tokenization spaces.
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
""" """
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
text = self.convert_tokens_to_string(filtered_tokens) text = self.convert_tokens_to_string(filtered_tokens)
if clean_up_tokenization_spaces: if clean_up_tokenization_spaces:
text = clean_up_tokenization(text) text = self.clean_up_tokenization(text)
return text return text
@property @property
...@@ -497,10 +605,11 @@ class PreTrainedTokenizer(object): ...@@ -497,10 +605,11 @@ class PreTrainedTokenizer(object):
all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks) all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
return all_ids return all_ids
@staticmethod
def clean_up_tokenization(out_string):
def clean_up_tokenization(out_string): """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' """
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
return out_string ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return out_string
...@@ -23,7 +23,7 @@ from shutil import copyfile ...@@ -23,7 +23,7 @@ from shutil import copyfile
import unicodedata import unicodedata
import six import six
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
# PyTorch # PyTorch
torch>=0.4.1 torch>=1.0.0
# progress bars in model download and training scripts # progress bars in model download and training scripts
tqdm tqdm
# Accessing files from S3 directly. # Accessing files from S3 directly.
......
...@@ -49,7 +49,7 @@ setup( ...@@ -49,7 +49,7 @@ setup(
url="https://github.com/huggingface/pytorch-transformers", url="https://github.com/huggingface/pytorch-transformers",
packages=find_packages(exclude=["*.tests", "*.tests.*", packages=find_packages(exclude=["*.tests", "*.tests.*",
"tests.*", "tests"]), "tests.*", "tests"]),
install_requires=['torch>=0.4.1', install_requires=['torch>=1.0.0',
'numpy', 'numpy',
'boto3', 'boto3',
'requests', 'requests',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment