Commit fa84ae26 authored by Aymeric Augustin

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.

This is also Thomas' preference, because it allows for explicit variable
names, which makes the code easier to understand.
parent 63e3827c
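
Black also accepts a --check flag, which reports files it would reformat without modifying them. A minimal sketch of how later runs could verify that the tree stays formatted, assuming the same target list as the command above:

    $ black --check --line-length 119 examples templates transformers utils hubconf.py setup.py

Alternatively, the line length could be pinned in a [tool.black] section of pyproject.toml (if the project adds one) so the flag does not have to be repeated on every invocation.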
@@ -17,13 +17,13 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest

from transformers.tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE

from .tokenization_tests_commons import CommonTestCases
from .utils import slow

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")


class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):

@@ -40,55 +40,135 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
        return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = "This is a test"
        output_text = "This is a test"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "é",
                ".",
            ],
        )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "<unk>",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "<unk>",
                ".",
            ],
        )

    def test_tokenizer_lower(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "",
                "i",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "se",
                ".",
            ],
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["▁he", "ll", "o"])

    def test_tokenizer_no_lower(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "se",
                ".",
            ],
        )

    @slow
    def test_sequence_builders(self):

@@ -104,5 +184,5 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
        assert encoded_pair == text + [4] + text_2 + [4, 3]


if __name__ == "__main__":
    unittest.main()
@@ -27,6 +27,7 @@ def parse_flag_from_env(key, default=False):
            raise ValueError("If set, {} must be yes or no.".format(key))
    return _value


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
...
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for ALBERT model."""
from __future__ import absolute_import, division, print_function, unicode_literals

from .tokenization_utils import PreTrainedTokenizer

import logging

@@ -24,34 +23,34 @@ import os
from shutil import copyfile

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-spiece.model",
        "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-spiece.model",
        "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-spiece.model",
        "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-spiece.model",
        "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model",
        "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-spiece.model",
        "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-spiece.model",
        "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-spiece.model",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "albert-base-v1": 512,
    "albert-large-v1": 512,
    "albert-xlarge-v1": 512,
    "albert-xxlarge-v1": 512,
    "albert-base-v2": 512,
    "albert-large-v2": 512,
    "albert-xlarge-v2": 512,
    "albert-xxlarge-v2": 512,
}

SPIECE_UNDERLINE = "▁"


class AlbertTokenizer(PreTrainedTokenizer):
    """

@@ -59,18 +58,36 @@ class AlbertTokenizer(PreTrainedTokenizer):
        - requires `SentencePiece <https://github.com/google/sentencepiece>`_
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        remove_space=True,
        keep_accents=False,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="<unk>",
        sep_token="[SEP]",
        pad_token="<pad>",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        super(AlbertTokenizer, self).__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs
        )

        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens

@@ -78,8 +95,10 @@ class AlbertTokenizer(PreTrainedTokenizer):
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
                "pip install sentencepiece"
            )

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space

@@ -103,24 +122,26 @@ class AlbertTokenizer(PreTrainedTokenizer):
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
                "pip install sentencepiece"
            )
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def preprocess_text(self, inputs):
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs
        outputs = outputs.replace("``", '"').replace("''", '"')
        if six.PY2 and isinstance(outputs, str):
            outputs = outputs.decode("utf-8")

        if not self.keep_accents:
            outputs = unicodedata.normalize("NFKD", outputs)
            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
        if self.do_lower_case:
            outputs = outputs.lower()

@@ -133,7 +154,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
        text = self.preprocess_text(text)
        # note(zhiliny): in some systems, sentencepiece only accepts str for py2
        if six.PY2 and isinstance(text, unicode):
            text = text.encode("utf-8")

        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)

@@ -141,9 +162,8 @@ class AlbertTokenizer(PreTrainedTokenizer):
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        new_pieces = []
        for piece in pieces:
            if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]

@@ -159,7 +179,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
            ret_pieces = []
            for piece in new_pieces:
                if isinstance(piece, str):
                    piece = piece.decode("utf-8")
                ret_pieces.append(piece)
            new_pieces = ret_pieces

@@ -173,12 +193,12 @@ class AlbertTokenizer(PreTrainedTokenizer):
        """Converts an index (integer) in a token (string/unicode) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        if six.PY2 and return_unicode and isinstance(token, str):
            token = token.decode("utf-8")
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):

@@ -213,8 +233,10 @@ class AlbertTokenizer(PreTrainedTokenizer):
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:

@@ -244,7 +266,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
...
@@ -35,6 +35,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
logger = logging.getLogger(__name__)


class AutoTokenizer(object):
    r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
        that will be instantiated as one of the tokenizer classes of the library

@@ -62,9 +63,12 @@ class AutoTokenizer(object):
        This class cannot be instantiated using `__init__()` (throw an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

@@ -125,34 +129,38 @@ class AutoTokenizer(object):
            tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')

        """
        if "t5" in pretrained_model_name_or_path:
            return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "distilbert" in pretrained_model_name_or_path:
            return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "albert" in pretrained_model_name_or_path:
            return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "camembert" in pretrained_model_name_or_path:
            return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "bert-base-japanese" in pretrained_model_name_or_path:
            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "bert" in pretrained_model_name_or_path:
            return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "openai-gpt" in pretrained_model_name_or_path:
            return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "gpt2" in pretrained_model_name_or_path:
            return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "transfo-xl" in pretrained_model_name_or_path:
            return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "xlm" in pretrained_model_name_or_path:
            return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif "ctrl" in pretrained_model_name_or_path:
            return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contains one of "
            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
            "'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(
                pretrained_model_name_or_path
            )
        )
@@ -26,69 +26,68 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
        "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
        "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
        "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
        "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
        "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
        "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
        "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
        "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
        "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
        "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
        "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
        "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
        "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
        "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
        "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "bert-base-uncased": 512,
    "bert-large-uncased": 512,
    "bert-base-cased": 512,
    "bert-large-cased": 512,
    "bert-base-multilingual-uncased": 512,
    "bert-base-multilingual-cased": 512,
    "bert-base-chinese": 512,
    "bert-base-german-cased": 512,
    "bert-large-uncased-whole-word-masking": 512,
    "bert-large-cased-whole-word-masking": 512,
    "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
    "bert-large-cased-whole-word-masking-finetuned-squad": 512,
    "bert-base-cased-finetuned-mrpc": 512,
    "bert-base-german-dbmdz-cased": 512,
    "bert-base-german-dbmdz-uncased": 512,
    "bert-base-finnish-cased-v1": 512,
    "bert-base-finnish-uncased-v1": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    "bert-base-uncased": {"do_lower_case": True},
    "bert-large-uncased": {"do_lower_case": True},
    "bert-base-cased": {"do_lower_case": False},
    "bert-large-cased": {"do_lower_case": False},
    "bert-base-multilingual-uncased": {"do_lower_case": True},
    "bert-base-multilingual-cased": {"do_lower_case": False},
    "bert-base-chinese": {"do_lower_case": False},
    "bert-base-german-cased": {"do_lower_case": False},
    "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
    "bert-large-cased-whole-word-masking": {"do_lower_case": False},
    "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
    "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
    "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
    "bert-base-german-dbmdz-cased": {"do_lower_case": False},
    "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
    "bert-base-finnish-cased-v1": {"do_lower_case": False},
    "bert-base-finnish-uncased-v1": {"do_lower_case": True},
}

@@ -98,7 +97,7 @@ def load_vocab(vocab_file):
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab

@@ -132,9 +131,20 @@ class BertTokenizer(PreTrainedTokenizer):
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        **kwargs
    ):
        """Constructs a BertTokenizer.

        Args:

@@ -152,24 +162,29 @@ class BertTokenizer(PreTrainedTokenizer):
                This should likely be deactivated for Japanese:
                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
        """
        super(BertTokenizer, self).__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs
        )
        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

    @property

@@ -196,7 +211,7 @@ class BertTokenizer(PreTrainedTokenizer):
    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):

@@ -231,8 +246,10 @@ class BertTokenizer(PreTrainedTokenizer):
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:

@@ -258,16 +275,18 @@ class BertTokenizer(PreTrainedTokenizer):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)

@@ -382,14 +401,16 @@ class BasicTokenizer(object):
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

@@ -399,7 +420,7 @@ class BasicTokenizer(object):
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")

@@ -499,8 +520,7 @@ def _is_punctuation(char):
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
...
@@ -28,46 +28,45 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-vocab.txt",
        "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-vocab.txt",
        "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-vocab.txt",
        "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "bert-base-japanese": 512,
    "bert-base-japanese-whole-word-masking": 512,
    "bert-base-japanese-char": 512,
    "bert-base-japanese-char-whole-word-masking": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    "bert-base-japanese": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "wordpiece",
    },
    "bert-base-japanese-whole-word-masking": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "wordpiece",
    },
    "bert-base-japanese-char": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "character",
    },
    "bert-base-japanese-char-whole-word-masking": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "character",
    },
}

@@ -79,11 +78,22 @@ class BertJapaneseTokenizer(BertTokenizer):
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=False,
        do_word_tokenize=True,
        do_subword_tokenize=True,
        word_tokenizer_type="basic",
        subword_tokenizer_type="wordpiece",
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        """Constructs a MecabBertTokenizer.

        Args:

@@ -100,56 +110,53 @@ class BertJapaneseTokenizer(BertTokenizer):
            **subword_tokenizer_type**: (`optional`) string (default "wordpiece")
                Type of subword tokenizer.
        """
        super(BertTokenizer, self).__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs
        )
        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

        self.do_word_tokenize = do_word_tokenize
        if do_word_tokenize:
            if word_tokenizer_type == "basic":
                self.word_tokenizer = BasicTokenizer(
                    do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=False
                )
            elif word_tokenizer_type == "mecab":
                self.word_tokenizer = MecabTokenizer(do_lower_case=do_lower_case, never_split=never_split)
            else:
                raise ValueError("Invalid word_tokenizer_type '{}' is specified.".format(word_tokenizer_type))

        self.do_subword_tokenize = do_subword_tokenize
        if do_subword_tokenize:
            if subword_tokenizer_type == "wordpiece":
                self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
            elif subword_tokenizer_type == "character":
                self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token)
            else:
                raise ValueError("Invalid subword_tokenizer_type '{}' is specified.".format(subword_tokenizer_type))

    def _tokenize(self, text):
        if self.do_word_tokenize:
            tokens = self.word_tokenizer.tokenize(text, never_split=self.all_special_tokens)
        else:
            tokens = [text]

        if self.do_subword_tokenize:
            split_tokens = [sub_token for token in tokens for sub_token in self.subword_tokenizer.tokenize(token)]
        else:
            split_tokens = tokens

@@ -177,27 +184,28 @@ class MecabTokenizer(object):
        self.normalize_text = normalize_text

        import MeCab

        self.mecab = MeCab.Tagger()

    def tokenize(self, text, never_split=None, **kwargs):
        """Tokenizes a piece of text."""
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)

        never_split = self.never_split + (never_split if never_split is not None else [])
        tokens = []

        if six.PY2:
            mecab_output = self.mecab.parse(text.encode("utf-8")).decode("utf-8")
        else:
            mecab_output = self.mecab.parse(text)

        cursor = 0
        for line in mecab_output.split("\n"):
            if line == "EOS":
                break

            token, _ = line.split("\t")
            token_start = text.index(token, cursor)
            token_end = token_start + len(token)
            if self.do_lower_case and token not in never_split:

@@ -240,7 +248,7 @@ class CharacterTokenizer(object):
                A list of characters.
        """
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)

        output_tokens = []
        for i, char in enumerate(text):
...
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
""" Tokenization classes for Camembert model.""" """ Tokenization classes for Camembert model."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import logging import logging
import os import os
...@@ -26,19 +25,19 @@ from .tokenization_xlnet import SPIECE_UNDERLINE ...@@ -26,19 +25,19 @@ from .tokenization_xlnet import SPIECE_UNDERLINE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'} VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model",
'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model",
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'camembert-base': None, "camembert-base": None,
} }
class CamembertTokenizer(PreTrainedTokenizer): class CamembertTokenizer(PreTrainedTokenizer):
""" """
Adapted from RobertaTokenizer and XLNetTokenizer Adapted from RobertaTokenizer and XLNetTokenizer
...@@ -46,17 +45,36 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -46,17 +45,36 @@ class CamembertTokenizer(PreTrainedTokenizer):
- requires `SentencePiece <https://github.com/google/sentencepiece>`_ - requires `SentencePiece <https://github.com/google/sentencepiece>`_
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>", def __init__(
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', self,
additional_special_tokens=['<s>NOTUSED', '</s>NOTUSED'], **kwargs): vocab_file,
super(CamembertTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, bos_token="<s>",
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, eos_token="</s>",
mask_token=mask_token, additional_special_tokens=additional_special_tokens, sep_token="</s>",
**kwargs) cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
additional_special_tokens=["<s>NOTUSED", "</s>NOTUSED"],
**kwargs
):
super(CamembertTokenizer, self).__init__(
max_len=512,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
self.sp_model = spm.SentencePieceProcessor() self.sp_model = spm.SentencePieceProcessor()
...@@ -64,9 +82,9 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -64,9 +82,9 @@ class CamembertTokenizer(PreTrainedTokenizer):
self.vocab_file = vocab_file self.vocab_file = vocab_file
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual # HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
# sentencepiece vocabulary (this is the case for <s> and </s>) # sentencepiece vocabulary (this is the case for <s> and </s>)
self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3} self.fairseq_tokens_to_ids = {"<s>NOTUSED": 0, "<pad>": 1, "</s>NOTUSED": 2, "<unk>": 3}
self.fairseq_offset = len(self.fairseq_tokens_to_ids) self.fairseq_offset = len(self.fairseq_tokens_to_ids)
self.fairseq_tokens_to_ids['<mask>'] = len(self.sp_model) + len(self.fairseq_tokens_to_ids) self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
...@@ -100,8 +118,10 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -100,8 +118,10 @@ class CamembertTokenizer(PreTrainedTokenizer):
""" """
if already_has_special_tokens: if already_has_special_tokens:
if token_ids_1 is not None: if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of " raise ValueError(
"ids is already formated with special tokens for the model.") "You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is None: if token_ids_1 is None:
...@@ -148,7 +168,7 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -148,7 +168,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string.""" """Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string return out_string
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
...@@ -158,7 +178,7 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -158,7 +178,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file) copyfile(self.vocab_file, out_vocab_file)
......
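To make the fairseq-style id bookkeeping in CamembertTokenizer.__init__ above concrete, here is a minimal sketch with a toy stand-in vocabulary. The token-to-id conversion method itself is not part of the hunks shown, so using fairseq_offset to shift SentencePiece ids is an assumption; the reserved ids and the <mask> placement mirror the dictionaries set up above.

# Sketch only: toy stand-in for the SentencePiece vocabulary (not the real camembert-base model).
sp_vocab = ["▁Bonjour", "▁le", "▁monde"]
sp_ids = {piece: i for i, piece in enumerate(sp_vocab)}

fairseq_tokens_to_ids = {"<s>NOTUSED": 0, "<pad>": 1, "</s>NOTUSED": 2, "<unk>": 3}
fairseq_offset = len(fairseq_tokens_to_ids)  # 4 reserved ids at the bottom of the table
fairseq_tokens_to_ids["<mask>"] = len(sp_vocab) + len(fairseq_tokens_to_ids)  # appended after the sp vocab

def token_to_id(token):
    # Assumed behaviour: reserved fairseq tokens keep their ids, SentencePiece pieces are shifted by the offset.
    if token in fairseq_tokens_to_ids:
        return fairseq_tokens_to_ids[token]
    return sp_ids[token] + fairseq_offset

print(token_to_id("<pad>"))     # 1
print(token_to_id("▁Bonjour"))  # 0 + 4 = 4
print(token_to_id("<mask>"))    # 7 with this toy 3-piece vocabulary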
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for Salesforce CTRL.""" """Tokenization classes for Salesforce CTRL."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import json import json
import logging import logging
...@@ -27,23 +26,17 @@ from .tokenization_utils import PreTrainedTokenizer ...@@ -27,23 +26,17 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', "vocab_file": "vocab.json",
'merges_file': 'merges.txt', "merges_file": "merges.txt",
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json",},
{ "merges_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt",},
'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json",
},
'merges_file':
{
'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt",
},
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'ctrl': 256, "ctrl": 256,
} }
CONTROL_CODES = { CONTROL_CODES = {
...@@ -104,6 +97,7 @@ CONTROL_CODES = { ...@@ -104,6 +97,7 @@ CONTROL_CODES = {
"multilingual": 128406, "multilingual": 128406,
} }
def get_pairs(word): def get_pairs(word):
"""Return set of symbol pairs in a word. """Return set of symbol pairs in a word.
...@@ -118,11 +112,13 @@ def get_pairs(word): ...@@ -118,11 +112,13 @@ def get_pairs(word):
pairs = set(pairs) pairs = set(pairs)
return pairs return pairs
class CTRLTokenizer(PreTrainedTokenizer): class CTRLTokenizer(PreTrainedTokenizer):
""" """
CTRL BPE tokenizer. Peculiarities: CTRL BPE tokenizer. Peculiarities:
- Byte-Pair-Encoding - Byte-Pair-Encoding
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
...@@ -130,14 +126,18 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -130,14 +126,18 @@ class CTRLTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs): def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs) super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len_single_sentence = (
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
with open(vocab_file, encoding="utf-8") as vocab_handle: with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle) self.encoder = json.load(vocab_handle)
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v: k for k, v in self.encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle: with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split('\n')[1:-1] merges = merges_handle.read().split("\n")[1:-1]
merges = [tuple(merge.split()) for merge in merges] merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {} self.cache = {}
...@@ -150,14 +150,14 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -150,14 +150,14 @@ class CTRLTokenizer(PreTrainedTokenizer):
if token in self.cache: if token in self.cache:
return self.cache[token] return self.cache[token]
word = tuple(token) word = tuple(token)
word = tuple(list(word[:-1]) + [word[-1]+'</w>']) word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
pairs = get_pairs(word) pairs = get_pairs(word)
if not pairs: if not pairs:
return token return token
while True: while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
...@@ -172,8 +172,8 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -172,8 +172,8 @@ class CTRLTokenizer(PreTrainedTokenizer):
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first+second) new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
...@@ -184,7 +184,7 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -184,7 +184,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
break break
else: else:
pairs = get_pairs(word) pairs = get_pairs(word)
word = '@@ '.join(word) word = "@@ ".join(word)
word = word[:-4] word = word[:-4]
self.cache[token] = word self.cache[token] = word
return word return word
...@@ -194,10 +194,10 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -194,10 +194,10 @@ class CTRLTokenizer(PreTrainedTokenizer):
""" """
split_tokens = [] split_tokens = []
words = re.findall(r'\S+\n?', text) words = re.findall(r"\S+\n?", text)
for token in words: for token in words:
split_tokens.extend([t for t in self.bpe(token).split(' ')]) split_tokens.extend([t for t in self.bpe(token).split(" ")])
return split_tokens return split_tokens
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
...@@ -210,7 +210,7 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -210,7 +210,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """ """ Converts a sequence of tokens (string) in a single string. """
out_string = ' '.join(tokens).replace('@@ ', '').strip() out_string = " ".join(tokens).replace("@@ ", "").strip()
return out_string return out_string
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
...@@ -218,21 +218,23 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -218,21 +218,23 @@ class CTRLTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f: with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False)) f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0 index = 0
with open(merge_file, "w", encoding="utf-8") as writer: with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n') writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." logger.warning(
" Please check that the tokenizer is not corrupted!".format(merge_file)) "Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index index = token_index
writer.write(' '.join(bpe_tokens) + u'\n') writer.write(" ".join(bpe_tokens) + "\n")
index += 1 index += 1
return vocab_file, merge_file return vocab_file, merge_file
......
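The bpe method above repeatedly merges the lowest-ranked adjacent symbol pair until no known merge remains. Below is a condensed, self-contained sketch of that loop (not the exact library code) using a hand-made two-entry merge table; in the real tokenizer the ranks come from merges.txt.

def get_pairs(word):
    return {(word[i], word[i + 1]) for i in range(len(word) - 1)}

bpe_ranks = {("l", "o"): 0, ("lo", "w</w>"): 1}  # lower rank = merged earlier
word = ("l", "o", "w</w>")                       # "low" plus the CTRL end-of-word marker

while len(word) > 1:
    bigram = min(get_pairs(word), key=lambda pair: bpe_ranks.get(pair, float("inf")))
    if bigram not in bpe_ranks:
        break
    first, second = bigram
    new_word, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
            new_word.append(first + second)  # merge the pair into one symbol
            i += 2
        else:
            new_word.append(word[i])
            i += 1
    word = tuple(new_word)

print("@@ ".join(word)[:-4])  # prints "low"; sub-words are joined with "@@ " and the trailing "</w>" is dropped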
...@@ -26,23 +26,22 @@ from .tokenization_bert import BertTokenizer ...@@ -26,23 +26,22 @@ from .tokenization_bert import BertTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", "distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", "distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt",
'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt", "distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'distilbert-base-uncased': 512, "distilbert-base-uncased": 512,
'distilbert-base-uncased-distilled-squad': 512, "distilbert-base-uncased-distilled-squad": 512,
'distilbert-base-german-cased': 512, "distilbert-base-german-cased": 512,
'distilbert-base-multilingual-cased': 512, "distilbert-base-multilingual-cased": 512,
} }
......
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for OpenAI GPT.""" """Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import sys import sys
import json import json
...@@ -31,42 +30,42 @@ except ImportError: ...@@ -31,42 +30,42 @@ except ImportError:
def lru_cache(): def lru_cache():
return lambda func: func return lambda func: func
from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', "vocab_file": "vocab.json",
'merges_file': 'merges.txt', "merges_file": "merges.txt",
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json",
'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json", "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
}, },
'merges_file': "merges_file": {
{ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt",
'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt", "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
}, },
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'gpt2': 1024, "gpt2": 1024,
'gpt2-medium': 1024, "gpt2-medium": 1024,
'gpt2-large': 1024, "gpt2-large": 1024,
'gpt2-xl': 1024, "gpt2-xl": 1024,
'distilgpt2': 1024, "distilgpt2": 1024,
} }
@lru_cache() @lru_cache()
def bytes_to_unicode(): def bytes_to_unicode():
""" """
...@@ -80,17 +79,20 @@ def bytes_to_unicode(): ...@@ -80,17 +79,20 @@ def bytes_to_unicode():
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
""" """
_chr = unichr if sys.version_info[0] == 2 else chr _chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:] cs = bs[:]
n = 0 n = 0
for b in range(2**8): for b in range(2 ** 8):
if b not in bs: if b not in bs:
bs.append(b) bs.append(b)
cs.append(2**8+n) cs.append(2 ** 8 + n)
n += 1 n += 1
cs = [_chr(n) for n in cs] cs = [_chr(n) for n in cs]
return dict(zip(bs, cs)) return dict(zip(bs, cs))
def get_pairs(word): def get_pairs(word):
"""Return set of symbol pairs in a word. """Return set of symbol pairs in a word.
...@@ -103,6 +105,7 @@ def get_pairs(word): ...@@ -103,6 +105,7 @@ def get_pairs(word):
prev_char = char prev_char = char
return pairs return pairs
class GPT2Tokenizer(PreTrainedTokenizer): class GPT2Tokenizer(PreTrainedTokenizer):
""" """
GPT-2 BPE tokenizer. Peculiarities: GPT-2 BPE tokenizer. Peculiarities:
...@@ -112,15 +115,28 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -112,15 +115,28 @@ class GPT2Tokenizer(PreTrainedTokenizer):
Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve
the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"` the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"`
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>", def __init__(
bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs): self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
**kwargs
):
super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len_single_sentence = (
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
with open(vocab_file, encoding="utf-8") as vocab_handle: with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle) self.encoder = json.load(vocab_handle)
...@@ -128,8 +144,8 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -128,8 +144,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
self.errors = errors # how to handle errors in decoding self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode() self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle: with open(merges_file, encoding="utf-8") as merges_handle:
bpe_merges = merges_handle.read().split('\n')[1:-1] bpe_merges = merges_handle.read().split("\n")[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_merges] bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {} self.cache = {}
...@@ -151,7 +167,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -151,7 +167,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return token return token
while True: while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
...@@ -166,8 +182,8 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -166,8 +182,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first+second) new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
...@@ -178,7 +194,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -178,7 +194,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
break break
else: else:
pairs = get_pairs(word) pairs = get_pairs(word)
word = ' '.join(word) word = " ".join(word)
self.cache[token] = word self.cache[token] = word
return word return word
...@@ -189,15 +205,19 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -189,15 +205,19 @@ class GPT2Tokenizer(PreTrainedTokenizer):
Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers. Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
""" """
if add_prefix_space: if add_prefix_space:
text = ' ' + text text = " " + text
bpe_tokens = [] bpe_tokens = []
for token in re.findall(self.pat, text): for token in re.findall(self.pat, text):
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) token = "".join(
self.byte_encoder[ord(b)] for b in token
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
else: else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) token = "".join(
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) self.byte_encoder[b] for b in token.encode("utf-8")
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens return bpe_tokens
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
...@@ -210,8 +230,8 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -210,8 +230,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """ """ Converts a sequence of tokens (string) in a single string. """
text = ''.join(tokens) text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text return text
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
...@@ -219,21 +239,23 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -219,21 +239,23 @@ class GPT2Tokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f: with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False)) f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0 index = 0
with open(merge_file, "w", encoding="utf-8") as writer: with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n') writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." logger.warning(
" Please check that the tokenizer is not corrupted!".format(merge_file)) "Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index index = token_index
writer.write(' '.join(bpe_tokens) + u'\n') writer.write(" ".join(bpe_tokens) + "\n")
index += 1 index += 1
return vocab_file, merge_file return vocab_file, merge_file
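As the GPT2Tokenizer docstring above notes, a leading space is only preserved when add_prefix_space is used, because the tokenizer works on bytes remapped to printable unicode characters by the bytes_to_unicode helper earlier in this file's diff. The sketch below restates that mapping for Python 3 only (dropping the unichr branch) and shows its effect on " Hello":

def bytes_to_unicode():
    # Same construction as above: printable bytes map to themselves, the rest to 256, 257, ...
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
print(byte_encoder[ord(" ")])                                      # 'Ġ': the space byte gets a visible stand-in
print("".join(byte_encoder[b] for b in " Hello".encode("utf-8")))  # 'ĠHello'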
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for OpenAI GPT.""" """Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import json import json
import logging import logging
...@@ -28,25 +27,20 @@ from .tokenization_bert import BasicTokenizer ...@@ -28,25 +27,20 @@ from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', "vocab_file": "vocab.json",
'merges_file': 'merges.txt', "merges_file": "merges.txt",
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",},
{ "merges_file": {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",},
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
},
'merges_file':
{
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
},
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'openai-gpt': 512, "openai-gpt": 512,
} }
def get_pairs(word): def get_pairs(word):
""" """
Return set of symbol pairs in a word. Return set of symbol pairs in a word.
...@@ -59,27 +53,30 @@ def get_pairs(word): ...@@ -59,27 +53,30 @@ def get_pairs(word):
prev_char = char prev_char = char
return pairs return pairs
def text_standardize(text): def text_standardize(text):
""" """
fixes some issues the spacy tokenizer had on books corpus fixes some issues the spacy tokenizer had on books corpus
also does some whitespace standardization also does some whitespace standardization
""" """
text = text.replace('—', '-') text = text.replace("—", "-")
text = text.replace('–', '-') text = text.replace("–", "-")
text = text.replace('―', '-') text = text.replace("―", "-")
text = text.replace('…', '...') text = text.replace("…", "...")
text = text.replace('´', "'") text = text.replace("´", "'")
text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text)
text = re.sub(r'\s*\n\s*', ' \n ', text) text = re.sub(r"\s*\n\s*", " \n ", text)
text = re.sub(r'[^\S\n]+', ' ', text) text = re.sub(r"[^\S\n]+", " ", text)
return text.strip() return text.strip()
class OpenAIGPTTokenizer(PreTrainedTokenizer): class OpenAIGPTTokenizer(PreTrainedTokenizer):
""" """
BPE tokenizer. Peculiarities: BPE tokenizer. Peculiarities:
- lower case all inputs - lower case all inputs
- uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not. - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not.
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
...@@ -87,12 +84,17 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -87,12 +84,17 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs): def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs) super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len_single_sentence = (
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
try: try:
import ftfy import ftfy
from spacy.lang.en import English from spacy.lang.en import English
_nlp = English() _nlp = English()
self.nlp = _nlp.Defaults.create_tokenizer(_nlp) self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
self.fix_text = ftfy.fix_text self.fix_text = ftfy.fix_text
...@@ -103,9 +105,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -103,9 +105,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
with open(vocab_file, encoding="utf-8") as vocab_handle: with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle) self.encoder = json.load(vocab_handle)
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v: k for k, v in self.encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle: with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split('\n')[1:-1] merges = merges_handle.read().split("\n")[1:-1]
merges = [tuple(merge.split()) for merge in merges] merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {} self.cache = {}
...@@ -115,16 +117,16 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -115,16 +117,16 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
return len(self.encoder) return len(self.encoder)
def bpe(self, token): def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',) word = tuple(token[:-1]) + (token[-1] + "</w>",)
if token in self.cache: if token in self.cache:
return self.cache[token] return self.cache[token]
pairs = get_pairs(word) pairs = get_pairs(word)
if not pairs: if not pairs:
return token+'</w>' return token + "</w>"
while True: while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
...@@ -139,8 +141,8 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -139,8 +141,8 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first+second) new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
...@@ -151,9 +153,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -151,9 +153,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
break break
else: else:
pairs = get_pairs(word) pairs = get_pairs(word)
word = ' '.join(word) word = " ".join(word)
if word == '\n </w>': if word == "\n </w>":
word = '\n</w>' word = "\n</w>"
self.cache[token] = word self.cache[token] = word
return word return word
...@@ -164,12 +166,12 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -164,12 +166,12 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
# Using BERT's BasicTokenizer # Using BERT's BasicTokenizer
text = self.nlp.tokenize(text) text = self.nlp.tokenize(text)
for token in text: for token in text:
split_tokens.extend([t for t in self.bpe(token).split(' ')]) split_tokens.extend([t for t in self.bpe(token).split(" ")])
else: else:
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT) # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
text = self.nlp(text_standardize(self.fix_text(text))) text = self.nlp(text_standardize(self.fix_text(text)))
for token in text: for token in text:
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) split_tokens.extend([t for t in self.bpe(token.text.lower()).split(" ")])
return split_tokens return split_tokens
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
...@@ -182,7 +184,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -182,7 +184,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """ """ Converts a sequence of tokens (string) in a single string. """
out_string = ''.join(tokens).replace('</w>', ' ').strip() out_string = "".join(tokens).replace("</w>", " ").strip()
return out_string return out_string
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
...@@ -190,21 +192,23 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -190,21 +192,23 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f: with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False)) f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0 index = 0
with open(merge_file, "w", encoding="utf-8") as writer: with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n') writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." logger.warning(
" Please check that the tokenizer is not corrupted!".format(merge_file)) "Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index index = token_index
writer.write(' '.join(bpe_tokens) + u'\n') writer.write(" ".join(bpe_tokens) + "\n")
index += 1 index += 1
return vocab_file, merge_file return vocab_file, merge_file
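The text_standardize helper above only normalizes exotic dashes, ellipses, apostrophes, punctuation runs, and whitespace before the spacy/ftfy tokenization step. A standalone copy applied to a small sample string shows the effect:

import re

def text_standardize(text):
    # Copied from the OpenAI GPT tokenizer above: normalize dashes/ellipses/apostrophes,
    # pad punctuation runs with spaces, and collapse whitespace (newlines are kept as " \n ").
    text = text.replace("—", "-")
    text = text.replace("–", "-")
    text = text.replace("―", "-")
    text = text.replace("…", "...")
    text = text.replace("´", "'")
    text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text)
    text = re.sub(r"\s*\n\s*", " \n ", text)
    text = re.sub(r"[^\S\n]+", " ", text)
    return text.strip()

print(text_standardize("Hello… it´s   fine!"))  # "Hello... it's fine !"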
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for RoBERTa.""" """Tokenization classes for RoBERTa."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import sys import sys
import json import json
...@@ -33,41 +32,40 @@ except ImportError: ...@@ -33,41 +32,40 @@ except ImportError:
def lru_cache(): def lru_cache():
return lambda func: func return lambda func: func
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', "vocab_file": "vocab.json",
'merges_file': 'merges.txt', "merges_file": "merges.txt",
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json", "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-vocab.json",
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-vocab.json", "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
}, },
'merges_file': "merges_file": {
{ "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt", "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-merges.txt",
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-merges.txt", "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
}, },
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'roberta-base': 512, "roberta-base": 512,
'roberta-large': 512, "roberta-large": 512,
'roberta-large-mnli': 512, "roberta-large-mnli": 512,
'distilroberta-base': 512, "distilroberta-base": 512,
'roberta-base-openai-detector': 512, "roberta-base-openai-detector": 512,
'roberta-large-openai-detector': 512, "roberta-large-openai-detector": 512,
} }
...@@ -80,16 +78,38 @@ class RobertaTokenizer(GPT2Tokenizer): ...@@ -80,16 +78,38 @@ class RobertaTokenizer(GPT2Tokenizer):
Otherwise, this tokenizer's ``encode`` and ``decode`` methods will not conserve Otherwise, this tokenizer's ``encode`` and ``decode`` methods will not conserve
the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"` the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>", def __init__(
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs): self,
super(RobertaTokenizer, self).__init__(vocab_file=vocab_file, merges_file=merges_file, errors=errors, vocab_file,
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, merges_file,
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, errors="replace",
mask_token=mask_token, **kwargs) bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs
):
super(RobertaTokenizer, self).__init__(
vocab_file=vocab_file,
merges_file=merges_file,
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
...@@ -124,8 +144,10 @@ class RobertaTokenizer(GPT2Tokenizer): ...@@ -124,8 +144,10 @@ class RobertaTokenizer(GPT2Tokenizer):
""" """
if already_has_special_tokens: if already_has_special_tokens:
if token_ids_1 is not None: if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of " raise ValueError(
"ids is already formated with special tokens for the model.") "You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is None: if token_ids_1 is None:
......
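For the already_has_special_tokens branch shown just above, the special-tokens mask is a per-position check against the cls/sep ids. A tiny sketch with assumed ids (0 for <s>, 2 for </s>, chosen purely for illustration):

cls_token_id, sep_token_id = 0, 2  # assumed ids for <s> and </s>, illustration only
token_ids_0 = [cls_token_id, 31414, 232, sep_token_id]  # ids that already include special tokens

mask = list(map(lambda x: 1 if x in [sep_token_id, cls_token_id] else 0, token_ids_0))
print(mask)  # [1, 0, 0, 1]; a 1 marks a special-token position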
...@@ -26,26 +26,25 @@ from .tokenization_utils import PreTrainedTokenizer ...@@ -26,26 +26,25 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SPIECE_UNDERLINE = u'▁' SPIECE_UNDERLINE = "▁"
#################################################### ####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__` # Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances # to file names for serializing Tokenizer instances
#################################################### ####################################################
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'} VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
#################################################### ####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__` # Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names. # to pretrained vocabulary URL for all the model shortcut names.
#################################################### ####################################################
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
} }
} }
...@@ -53,13 +52,14 @@ PRETRAINED_VOCAB_FILES_MAP = { ...@@ -53,13 +52,14 @@ PRETRAINED_VOCAB_FILES_MAP = {
# Mapping from model shortcut names to max length of inputs # Mapping from model shortcut names to max length of inputs
#################################################### ####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
't5-small': 512, "t5-small": 512,
't5-base': 512, "t5-base": 512,
't5-large': 512, "t5-large": 512,
't5-3b': 512, "t5-3b": 512,
't5-11b': 512, "t5-11b": 512,
} }
class T5Tokenizer(PreTrainedTokenizer): class T5Tokenizer(PreTrainedTokenizer):
""" """
SentencePiece based tokenizer. Peculiarities: SentencePiece based tokenizer. Peculiarities:
...@@ -71,28 +71,43 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -71,28 +71,43 @@ class T5Tokenizer(PreTrainedTokenizer):
(like in T5 preprocessing (like in T5 preprocessing
see: https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117) see: https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, eos_token="</s>", unk_token="<unk>", def __init__(
pad_token="<pad>", extra_ids=100, additional_special_tokens=None, **kwargs): self,
vocab_file,
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>",
extra_ids=100,
additional_special_tokens=None,
**kwargs
):
# Add extra_ids to the special token list # Add extra_ids to the special token list
if extra_ids > 0: if extra_ids > 0:
if additional_special_tokens is None: if additional_special_tokens is None:
additional_special_tokens = [] additional_special_tokens = []
additional_special_tokens.extend([u"<extra_id_{}>".format(i) for i in range(extra_ids)]) additional_special_tokens.extend(["<extra_id_{}>".format(i) for i in range(extra_ids)])
super(T5Tokenizer, self).__init__(eos_token=eos_token, unk_token=unk_token, super(T5Tokenizer, self).__init__(
pad_token=pad_token, additional_special_tokens=additional_special_tokens, eos_token=eos_token,
**kwargs) unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
try: try:
import sentencepiece as spm import sentencepiece as spm
except ImportError: except ImportError:
logger.warning("You need to install SentencePiece to use T5Tokenizer:" logger.warning(
"https://github.com/google/sentencepiece" "You need to install SentencePiece to use T5Tokenizer:"
"pip install sentencepiece") "https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.vocab_file = vocab_file self.vocab_file = vocab_file
self._extra_ids = extra_ids self._extra_ids = extra_ids
...@@ -114,8 +129,10 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -114,8 +129,10 @@ class T5Tokenizer(PreTrainedTokenizer):
try: try:
import sentencepiece as spm import sentencepiece as spm
except ImportError: except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" logger.warning(
"pip install sentencepiece") "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.sp_model = spm.SentencePieceProcessor() self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file) self.sp_model.Load(self.vocab_file)
...@@ -132,7 +149,7 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -132,7 +149,7 @@ class T5Tokenizer(PreTrainedTokenizer):
ret_pieces = [] ret_pieces = []
for piece in pieces: for piece in pieces:
if isinstance(piece, str): if isinstance(piece, str):
piece = piece.decode('utf-8') piece = piece.decode("utf-8")
ret_pieces.append(piece) ret_pieces.append(piece)
pieces = ret_pieces pieces = ret_pieces
...@@ -140,8 +157,8 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -140,8 +157,8 @@ class T5Tokenizer(PreTrainedTokenizer):
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """ """ Converts a token (str/unicode) in an id using the vocab. """
if token.startswith(u"<extra_id_"): if token.startswith("<extra_id_"):
l = re.match(r'<extra_id_(\d+)>', token) l = re.match(r"<extra_id_(\d+)>", token)
num = int(l.group(1)) num = int(l.group(1))
return self.vocab_size - num - 1 return self.vocab_size - num - 1
return self.sp_model.piece_to_id(token) return self.sp_model.piece_to_id(token)
...@@ -151,9 +168,9 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -151,9 +168,9 @@ class T5Tokenizer(PreTrainedTokenizer):
if index < self.sp_model.get_piece_size(): if index < self.sp_model.get_piece_size():
token = self.sp_model.IdToPiece(index) token = self.sp_model.IdToPiece(index)
else: else:
token = u"<extra_id_{}>".format(self.vocab_size - 1 - index) token = "<extra_id_{}>".format(self.vocab_size - 1 - index)
if six.PY2 and return_unicode and isinstance(token, str): if six.PY2 and return_unicode and isinstance(token, str):
token = token.decode('utf-8') token = token.decode("utf-8")
return token return token
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
...@@ -168,7 +185,7 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -168,7 +185,7 @@ class T5Tokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file) copyfile(self.vocab_file, out_vocab_file)
......
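The two conversion methods above place the extra_ids sentinels at the very top of the vocabulary: <extra_id_0> gets the highest id and <extra_id_99> the lowest of the sentinel block. A small sketch of that arithmetic, where a vocab_size of 32100 (32000 SentencePiece pieces plus 100 sentinels) is an assumption for illustration:

import re

vocab_size = 32100  # assumed: 32000 SentencePiece pieces + 100 <extra_id_*> sentinels

def extra_id_to_index(token):
    # Mirrors the sentinel branch of _convert_token_to_id above.
    num = int(re.match(r"<extra_id_(\d+)>", token).group(1))
    return vocab_size - num - 1

def index_to_extra_id(index):
    # Mirrors _convert_id_to_token above for ids beyond the SentencePiece range.
    return "<extra_id_{}>".format(vocab_size - 1 - index)

print(extra_id_to_index("<extra_id_0>"))   # 32099
print(extra_id_to_index("<extra_id_99>"))  # 32000
print(index_to_extra_id(32099))            # <extra_id_0>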
...@@ -16,8 +16,7 @@ ...@@ -16,8 +16,7 @@
""" Tokenization classes for Transformer XL model. """ Tokenization classes for Transformer XL model.
Adapted from https://github.com/kimiyoung/transformer-xl. Adapted from https://github.com/kimiyoung/transformer-xl.
""" """
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import glob import glob
import logging import logging
...@@ -44,42 +43,58 @@ except ImportError: ...@@ -44,42 +43,58 @@ except ImportError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'} VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'pretrained_vocab_file': "pretrained_vocab_file": {
{ "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'transfo-xl-wt103': None, "transfo-xl-wt103": None,
} }
PRETRAINED_CORPUS_ARCHIVE_MAP = { PRETRAINED_CORPUS_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin", "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
} }
CORPUS_NAME = 'corpus.bin' CORPUS_NAME = "corpus.bin"
class TransfoXLTokenizer(PreTrainedTokenizer): class TransfoXLTokenizer(PreTrainedTokenizer):
""" """
Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False, def __init__(
delimiter=None, vocab_file=None, pretrained_vocab_file=None, self,
never_split=None, unk_token="<unk>", eos_token="<eos>", special=None,
additional_special_tokens=["<formula>"], **kwargs): min_freq=0,
super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token, max_size=None,
additional_special_tokens=additional_special_tokens, lower_case=False,
**kwargs) delimiter=None,
vocab_file=None,
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens pretrained_vocab_file=None,
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens never_split=None,
unk_token="<unk>",
eos_token="<eos>",
additional_special_tokens=["<formula>"],
**kwargs
):
super(TransfoXLTokenizer, self).__init__(
unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs
)
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
if never_split is None: if never_split is None:
never_split = self.all_special_tokens never_split = self.all_special_tokens
...@@ -106,14 +121,15 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -106,14 +121,15 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.build_vocab() self.build_vocab()
def count_file(self, path, verbose=False, add_eos=False): def count_file(self, path, verbose=False, add_eos=False):
if verbose: logger.info('counting file {} ...'.format(path)) if verbose:
logger.info("counting file {} ...".format(path))
assert os.path.exists(path) assert os.path.exists(path)
sents = [] sents = []
with open(path, 'r', encoding='utf-8') as f: with open(path, "r", encoding="utf-8") as f:
for idx, line in enumerate(f): for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0: if verbose and idx > 0 and idx % 500000 == 0:
logger.info(' line {}'.format(idx)) logger.info(" line {}".format(idx))
symbols = self.tokenize(line, add_eos=add_eos) symbols = self.tokenize(line, add_eos=add_eos)
self.counter.update(symbols) self.counter.update(symbols)
sents.append(symbols) sents.append(symbols)
...@@ -124,42 +140,42 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -124,42 +140,42 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
""" """
sents : a list of sentences, each a list of tokenized symbols sents : a list of sentences, each a list of tokenized symbols
""" """
if verbose: logger.info('counting {} sents ...'.format(len(sents))) if verbose:
logger.info("counting {} sents ...".format(len(sents)))
for idx, symbols in enumerate(sents): for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0: if verbose and idx > 0 and idx % 500000 == 0:
logger.info(' line {}'.format(idx)) logger.info(" line {}".format(idx))
self.counter.update(symbols) self.counter.update(symbols)
def _build_from_file(self, vocab_file): def _build_from_file(self, vocab_file):
self.idx2sym = [] self.idx2sym = []
self.sym2idx = OrderedDict() self.sym2idx = OrderedDict()
with open(vocab_file, 'r', encoding='utf-8') as f: with open(vocab_file, "r", encoding="utf-8") as f:
for line in f: for line in f:
symb = line.strip().split()[0] symb = line.strip().split()[0]
self.add_symbol(symb) self.add_symbol(symb)
if '<UNK>' in self.sym2idx: if "<UNK>" in self.sym2idx:
self.unk_idx = self.sym2idx['<UNK>'] self.unk_idx = self.sym2idx["<UNK>"]
elif '<unk>' in self.sym2idx: elif "<unk>" in self.sym2idx:
self.unk_idx = self.sym2idx['<unk>'] self.unk_idx = self.sym2idx["<unk>"]
else: else:
raise ValueError('No <unknown> token in vocabulary') raise ValueError("No <unknown> token in vocabulary")
def save_vocabulary(self, vocab_path): def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file.""" """Save the tokenizer vocabulary to a directory or file."""
if os.path.isdir(vocab_path): if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['pretrained_vocab_file']) vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
torch.save(self.__dict__, vocab_file) torch.save(self.__dict__, vocab_file)
return (vocab_file,) return (vocab_file,)
def build_vocab(self): def build_vocab(self):
if self.vocab_file: if self.vocab_file:
logger.info('building vocab from {}'.format(self.vocab_file)) logger.info("building vocab from {}".format(self.vocab_file))
self._build_from_file(self.vocab_file) self._build_from_file(self.vocab_file)
logger.info('final vocab size {}'.format(len(self))) logger.info("final vocab size {}".format(len(self)))
else: else:
logger.info('building vocab with min_freq={}, max_size={}'.format( logger.info("building vocab with min_freq={}, max_size={}".format(self.min_freq, self.max_size))
self.min_freq, self.max_size))
self.idx2sym = [] self.idx2sym = []
self.sym2idx = OrderedDict() self.sym2idx = OrderedDict()
...@@ -167,23 +183,22 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -167,23 +183,22 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.add_special(sym) self.add_special(sym)
for sym, cnt in self.counter.most_common(self.max_size): for sym, cnt in self.counter.most_common(self.max_size):
if cnt < self.min_freq: break if cnt < self.min_freq:
break
self.add_symbol(sym) self.add_symbol(sym)
logger.info('final vocab size {} from {} unique tokens'.format( logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))
len(self), len(self.counter)))
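For context, a hedged sketch of the frequency-based vocabulary path above (count_file followed by build_vocab); the file name train.txt and the cut-off values are made up for illustration:

from transformers.tokenization_transfo_xl import TransfoXLTokenizer

# Build the vocabulary from token counts instead of an existing vocab file.
tok = TransfoXLTokenizer(special=["<eos>"], min_freq=2, max_size=50000)
tok.count_file("train.txt", add_eos=True)  # tallies whitespace-split symbols into tok.counter
tok.build_vocab()                          # keeps symbols seen at least min_freq times, up to max_size entries
print(len(tok))                            # final vocabulary size, as logged above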
def encode_file(self, path, ordered=False, verbose=False, add_eos=True, def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
add_double_eos=False): if verbose:
if verbose: logger.info('encoding file {} ...'.format(path)) logger.info("encoding file {} ...".format(path))
assert os.path.exists(path) assert os.path.exists(path)
encoded = [] encoded = []
with open(path, 'r', encoding='utf-8') as f: with open(path, "r", encoding="utf-8") as f:
for idx, line in enumerate(f): for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0: if verbose and idx > 0 and idx % 500000 == 0:
logger.info(' line {}'.format(idx)) logger.info(" line {}".format(idx))
symbols = self.tokenize(line, add_eos=add_eos, symbols = self.tokenize(line, add_eos=add_eos, add_double_eos=add_double_eos)
add_double_eos=add_double_eos)
encoded.append(self.convert_to_tensor(symbols)) encoded.append(self.convert_to_tensor(symbols))
if ordered: if ordered:
...@@ -192,11 +207,12 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -192,11 +207,12 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
return encoded return encoded
def encode_sents(self, sents, ordered=False, verbose=False): def encode_sents(self, sents, ordered=False, verbose=False):
if verbose: logger.info('encoding {} sents ...'.format(len(sents))) if verbose:
logger.info("encoding {} sents ...".format(len(sents)))
encoded = [] encoded = []
for idx, symbols in enumerate(sents): for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0: if verbose and idx > 0 and idx % 500000 == 0:
logger.info(' line {}'.format(idx)) logger.info(" line {}".format(idx))
encoded.append(self.convert_to_tensor(symbols)) encoded.append(self.convert_to_tensor(symbols))
if ordered: if ordered:
...@@ -208,7 +224,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -208,7 +224,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if sym not in self.sym2idx: if sym not in self.sym2idx:
self.idx2sym.append(sym) self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1 self.sym2idx[sym] = len(self.idx2sym) - 1
setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) setattr(self, "{}_idx".format(sym.strip("<>")), self.sym2idx[sym])
def add_symbol(self, sym): def add_symbol(self, sym):
if sym not in self.sym2idx: if sym not in self.sym2idx:
...@@ -217,7 +233,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -217,7 +233,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
def _convert_id_to_token(self, idx): def _convert_id_to_token(self, idx):
"""Converts an id in a token (BPE) using the vocab.""" """Converts an id in a token (BPE) using the vocab."""
assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx) assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
return self.idx2sym[idx] return self.idx2sym[idx]
def _convert_token_to_id(self, sym): def _convert_token_to_id(self, sym):
...@@ -227,19 +243,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -227,19 +243,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
else: else:
# logger.info('encounter unk {}'.format(sym)) # logger.info('encounter unk {}'.format(sym))
# assert '<eos>' not in sym # assert '<eos>' not in sym
if hasattr(self, 'unk_idx'): if hasattr(self, "unk_idx"):
return self.sym2idx.get(sym, self.unk_idx) return self.sym2idx.get(sym, self.unk_idx)
# Backward compatibility with pre-trained models # Backward compatibility with pre-trained models
elif '<unk>' in self.sym2idx: elif "<unk>" in self.sym2idx:
return self.sym2idx['<unk>'] return self.sym2idx["<unk>"]
elif '<UNK>' in self.sym2idx: elif "<UNK>" in self.sym2idx:
return self.sym2idx['<UNK>'] return self.sym2idx["<UNK>"]
else: else:
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement') raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """ """ Converts a sequence of tokens (string) in a single string. """
out_string = ' '.join(tokens).strip() out_string = " ".join(tokens).strip()
return out_string return out_string
def convert_to_tensor(self, symbols): def convert_to_tensor(self, symbols):
...@@ -256,21 +272,21 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -256,21 +272,21 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
line = line.lower() line = line.lower()
# empty delimiter '' will evaluate to False # empty delimiter '' will evaluate to False
if self.delimiter == '': if self.delimiter == "":
symbols = line symbols = line
else: else:
symbols = line.split(self.delimiter) symbols = line.split(self.delimiter)
if add_double_eos: # lm1b if add_double_eos: # lm1b
return ['<S>'] + symbols + ['<S>'] return ["<S>"] + symbols + ["<S>"]
elif add_eos: elif add_eos:
return symbols + ['<eos>'] return symbols + ["<eos>"]
else: else:
return symbols return symbols
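The eos handling above is easiest to see on a toy input; a hedged sketch with the defaults (whitespace delimiter, no lower-casing):

from transformers.tokenization_transfo_xl import TransfoXLTokenizer

tok = TransfoXLTokenizer()
tok.tokenize("the cat sat")                       # ["the", "cat", "sat"]
tok.tokenize("the cat sat", add_eos=True)         # ["the", "cat", "sat", "<eos>"]
tok.tokenize("the cat sat", add_double_eos=True)  # lm1b style: ["<S>", "the", "cat", "sat", "<S>"]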
class LMOrderedIterator(object): class LMOrderedIterator(object):
def __init__(self, data, bsz, bptt, device='cpu', ext_len=None): def __init__(self, data, bsz, bptt, device="cpu", ext_len=None):
""" """
data -- LongTensor -- the LongTensor is strictly ordered data -- LongTensor -- the LongTensor is strictly ordered
""" """
...@@ -293,14 +309,15 @@ class LMOrderedIterator(object): ...@@ -293,14 +309,15 @@ class LMOrderedIterator(object):
self.n_batch = (self.n_step + self.bptt - 1) // self.bptt self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
def get_batch(self, i, bptt=None): def get_batch(self, i, bptt=None):
if bptt is None: bptt = self.bptt if bptt is None:
bptt = self.bptt
seq_len = min(bptt, self.data.size(0) - 1 - i) seq_len = min(bptt, self.data.size(0) - 1 - i)
end_idx = i + seq_len end_idx = i + seq_len
beg_idx = max(0, i - self.ext_len) beg_idx = max(0, i - self.ext_len)
data = self.data[beg_idx:end_idx] data = self.data[beg_idx:end_idx]
target = self.data[i+1:i+1+seq_len] target = self.data[i + 1 : i + 1 + seq_len]
data_out = data.transpose(0, 1).contiguous().to(self.device) data_out = data.transpose(0, 1).contiguous().to(self.device)
target_out = target.transpose(0, 1).contiguous().to(self.device) target_out = target.transpose(0, 1).contiguous().to(self.device)
...@@ -315,7 +332,7 @@ class LMOrderedIterator(object): ...@@ -315,7 +332,7 @@ class LMOrderedIterator(object):
max_len = self.bptt + max_deviation * std max_len = self.bptt + max_deviation * std
i = start i = start
while True: while True:
bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0
bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
data, target, seq_len = self.get_batch(i, bptt) data, target, seq_len = self.get_batch(i, bptt)
i += seq_len i += seq_len
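To make the get_batch slicing above concrete, a small sketch assuming the usual Transformer-XL batchification (the flat id stream is split into bsz contiguous columns):

import torch
from transformers.tokenization_transfo_xl import LMOrderedIterator

stream = torch.arange(20)  # toy corpus of 20 token ids
it = LMOrderedIterator(stream, bsz=2, bptt=4)
data, target, seq_len = it.get_batch(0)
# data and target are (bsz, seq_len) tensors; target is data shifted by one position:
# data[0]   -> tensor([0, 1, 2, 3])
# target[0] -> tensor([1, 2, 3, 4])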
...@@ -328,7 +345,7 @@ class LMOrderedIterator(object): ...@@ -328,7 +345,7 @@ class LMOrderedIterator(object):
class LMShuffledIterator(object): class LMShuffledIterator(object):
def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): def __init__(self, data, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
""" """
data -- list[LongTensor] -- there is no order among the LongTensors data -- list[LongTensor] -- there is no order among the LongTensors
""" """
...@@ -343,8 +360,7 @@ class LMShuffledIterator(object): ...@@ -343,8 +360,7 @@ class LMShuffledIterator(object):
def get_sent_stream(self): def get_sent_stream(self):
# index iterator # index iterator
epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ epoch_indices = np.random.permutation(len(self.data)) if self.shuffle else np.array(range(len(self.data)))
else np.array(range(len(self.data)))
# sentence iterator # sentence iterator
for idx in epoch_indices: for idx in epoch_indices:
...@@ -376,10 +392,8 @@ class LMShuffledIterator(object): ...@@ -376,10 +392,8 @@ class LMShuffledIterator(object):
# number of new tokens to fill in # number of new tokens to fill in
n_new = min(len(streams[i]) - 1, self.bptt - n_filled) n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
# first n_retain tokens are retained from last batch # first n_retain tokens are retained from last batch
data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ data[n_retain + n_filled : n_retain + n_filled + n_new, i] = streams[i][:n_new]
streams[i][:n_new] target[n_filled : n_filled + n_new, i] = streams[i][1 : n_new + 1]
target[n_filled:n_filled+n_new, i] = \
streams[i][1:n_new+1]
streams[i] = streams[i][n_new:] streams[i] = streams[i][n_new:]
n_filled += n_new n_filled += n_new
except StopIteration: except StopIteration:
...@@ -408,8 +422,7 @@ class LMShuffledIterator(object): ...@@ -408,8 +422,7 @@ class LMShuffledIterator(object):
class LMMultiFileIterator(LMShuffledIterator): class LMMultiFileIterator(LMShuffledIterator):
def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None, def __init__(self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
shuffle=False):
self.paths = paths self.paths = paths
self.vocab = vocab self.vocab = vocab
...@@ -460,15 +473,16 @@ class TransfoXLCorpus(object): ...@@ -460,15 +473,16 @@ class TransfoXLCorpus(object):
"We assumed '{}' was a path or url but couldn't find files {} " "We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format( "at this path or url.".format(
pretrained_model_name_or_path, pretrained_model_name_or_path,
', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()), ", ".join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path, pretrained_model_name_or_path,
corpus_file)) corpus_file,
)
)
return None return None
if resolved_corpus_file == corpus_file: if resolved_corpus_file == corpus_file:
logger.info("loading corpus file {}".format(corpus_file)) logger.info("loading corpus file {}".format(corpus_file))
else: else:
logger.info("loading corpus file {} from cache at {}".format( logger.info("loading corpus file {} from cache at {}".format(corpus_file, resolved_corpus_file))
corpus_file, resolved_corpus_file))
# Instantiate tokenizer. # Instantiate tokenizer.
corpus = cls(*inputs, **kwargs) corpus = cls(*inputs, **kwargs)
...@@ -494,83 +508,78 @@ class TransfoXLCorpus(object): ...@@ -494,83 +508,78 @@ class TransfoXLCorpus(object):
def build_corpus(self, path, dataset): def build_corpus(self, path, dataset):
self.dataset = dataset self.dataset = dataset
if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']: if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
self.vocab.count_file(os.path.join(path, 'train.txt')) self.vocab.count_file(os.path.join(path, "train.txt"))
self.vocab.count_file(os.path.join(path, 'valid.txt')) self.vocab.count_file(os.path.join(path, "valid.txt"))
self.vocab.count_file(os.path.join(path, 'test.txt')) self.vocab.count_file(os.path.join(path, "test.txt"))
elif self.dataset == 'wt103': elif self.dataset == "wt103":
self.vocab.count_file(os.path.join(path, 'train.txt')) self.vocab.count_file(os.path.join(path, "train.txt"))
elif self.dataset == 'lm1b': elif self.dataset == "lm1b":
train_path_pattern = os.path.join( train_path_pattern = os.path.join(
path, '1-billion-word-language-modeling-benchmark-r13output', path,
'training-monolingual.tokenized.shuffled', 'news.en-*') "1-billion-word-language-modeling-benchmark-r13output",
"training-monolingual.tokenized.shuffled",
"news.en-*",
)
train_paths = glob.glob(train_path_pattern) train_paths = glob.glob(train_path_pattern)
# the vocab will load from file when build_vocab() is called # the vocab will load from file when build_vocab() is called
self.vocab.build_vocab() self.vocab.build_vocab()
if self.dataset in ['ptb', 'wt2', 'wt103']: if self.dataset in ["ptb", "wt2", "wt103"]:
self.train = self.vocab.encode_file( self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True)
os.path.join(path, 'train.txt'), ordered=True) self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True)
self.valid = self.vocab.encode_file( self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True)
os.path.join(path, 'valid.txt'), ordered=True) elif self.dataset in ["enwik8", "text8"]:
self.test = self.vocab.encode_file( self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True, add_eos=False)
os.path.join(path, 'test.txt'), ordered=True) self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
elif self.dataset in ['enwik8', 'text8']: self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True, add_eos=False)
self.train = self.vocab.encode_file( elif self.dataset == "lm1b":
os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
elif self.dataset == 'lm1b':
self.train = train_paths self.train = train_paths
self.valid = self.vocab.encode_file( self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True)
os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True) self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=False, add_double_eos=True)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
def get_iterator(self, split, *args, **kwargs): def get_iterator(self, split, *args, **kwargs):
if split == 'train': if split == "train":
if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
data_iter = LMOrderedIterator(self.train, *args, **kwargs) data_iter = LMOrderedIterator(self.train, *args, **kwargs)
elif self.dataset == 'lm1b': elif self.dataset == "lm1b":
kwargs['shuffle'] = True kwargs["shuffle"] = True
data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
elif split in ['valid', 'test']: elif split in ["valid", "test"]:
data = self.valid if split == 'valid' else self.test data = self.valid if split == "valid" else self.test
if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
data_iter = LMOrderedIterator(data, *args, **kwargs) data_iter = LMOrderedIterator(data, *args, **kwargs)
elif self.dataset == 'lm1b': elif self.dataset == "lm1b":
data_iter = LMShuffledIterator(data, *args, **kwargs) data_iter = LMShuffledIterator(data, *args, **kwargs)
return data_iter return data_iter
def get_lm_corpus(datadir, dataset): def get_lm_corpus(datadir, dataset):
fn = os.path.join(datadir, 'cache.pt') fn = os.path.join(datadir, "cache.pt")
fn_pickle = os.path.join(datadir, 'cache.pkl') fn_pickle = os.path.join(datadir, "cache.pkl")
if os.path.exists(fn): if os.path.exists(fn):
logger.info('Loading cached dataset...') logger.info("Loading cached dataset...")
corpus = torch.load(fn) corpus = torch.load(fn)
elif os.path.exists(fn_pickle): elif os.path.exists(fn_pickle):
logger.info('Loading cached dataset from pickle...') logger.info("Loading cached dataset from pickle...")
with open(fn_pickle, "rb") as fp: with open(fn_pickle, "rb") as fp:
corpus = pickle.load(fp) corpus = pickle.load(fp)
else: else:
logger.info('Producing dataset {}...'.format(dataset)) logger.info("Producing dataset {}...".format(dataset))
kwargs = {} kwargs = {}
if dataset in ['wt103', 'wt2']: if dataset in ["wt103", "wt2"]:
kwargs['special'] = ['<eos>'] kwargs["special"] = ["<eos>"]
kwargs['lower_case'] = False kwargs["lower_case"] = False
elif dataset == 'ptb': elif dataset == "ptb":
kwargs['special'] = ['<eos>'] kwargs["special"] = ["<eos>"]
kwargs['lower_case'] = True kwargs["lower_case"] = True
elif dataset == 'lm1b': elif dataset == "lm1b":
kwargs['special'] = [] kwargs["special"] = []
kwargs['lower_case'] = False kwargs["lower_case"] = False
kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt') kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt")
elif dataset in ['enwik8', 'text8']: elif dataset in ["enwik8", "text8"]:
pass pass
corpus = TransfoXLCorpus(datadir, dataset, **kwargs) corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
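For orientation, a hedged end-to-end sketch of how a corpus produced here is typically consumed; the directory name is hypothetical and, for the "wt103" branch above, must contain train.txt, valid.txt and test.txt:

from transformers.tokenization_transfo_xl import get_lm_corpus

corpus = get_lm_corpus("./wikitext-103", "wt103")               # builds the vocab and caches the encoded splits
tr_iter = corpus.get_iterator("train", 16, 128, device="cpu")   # bsz=16, bptt=128 -> LMOrderedIterator
for data, target, seq_len in tr_iter:
    # data and target are (bsz, seq_len) LongTensors of inputs and next-token targets
    break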
......
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for OpenAI GPT.""" """Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import logging import logging
import os import os
...@@ -34,9 +33,10 @@ if is_torch_available(): ...@@ -34,9 +33,10 @@ if is_torch_available():
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json' SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = 'added_tokens.json' ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = 'tokenizer_config.json' TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
class PreTrainedTokenizer(object): class PreTrainedTokenizer(object):
""" Base class for all tokenizers. """ Base class for all tokenizers.
...@@ -69,14 +69,22 @@ class PreTrainedTokenizer(object): ...@@ -69,14 +69,22 @@ class PreTrainedTokenizer(object):
- ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensures they won't be split by the tokenization process. Will be associated with ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids`` - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensures they won't be split by the tokenization process. Will be associated with ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
""" """
vocab_files_names = {} vocab_files_names = {}
pretrained_vocab_files_map = {} pretrained_vocab_files_map = {}
pretrained_init_configuration = {} pretrained_init_configuration = {}
max_model_input_sizes = {} max_model_input_sizes = {}
SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token", SPECIAL_TOKENS_ATTRIBUTES = [
"pad_token", "cls_token", "mask_token", "bos_token",
"additional_special_tokens"] "eos_token",
"unk_token",
"sep_token",
"pad_token",
"cls_token",
"mask_token",
"additional_special_tokens",
]
padding_side = "right" padding_side = "right"
...@@ -227,8 +235,8 @@ class PreTrainedTokenizer(object): ...@@ -227,8 +235,8 @@ class PreTrainedTokenizer(object):
self.max_len = max_len if max_len is not None else int(1e12) self.max_len = max_len if max_len is not None else int(1e12)
# Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed. # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
self.padding_side = kwargs.pop('padding_side', self.padding_side) self.padding_side = kwargs.pop("padding_side", self.padding_side)
# Added tokens # Added tokens
self.added_tokens_encoder = {} self.added_tokens_encoder = {}
self.unique_added_tokens_encoder = set() self.unique_added_tokens_encoder = set()
...@@ -240,13 +248,14 @@ class PreTrainedTokenizer(object): ...@@ -240,13 +248,14 @@ class PreTrainedTokenizer(object):
for key, value in kwargs.items(): for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == 'additional_special_tokens': if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value) assert isinstance(value, (list, tuple)) and all(
isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value
)
else: else:
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode)) assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
setattr(self, key, value) setattr(self, key, value)
@classmethod @classmethod
def from_pretrained(cls, *inputs, **kwargs): def from_pretrained(cls, *inputs, **kwargs):
r""" r"""
...@@ -302,13 +311,12 @@ class PreTrainedTokenizer(object): ...@@ -302,13 +311,12 @@ class PreTrainedTokenizer(object):
""" """
return cls._from_pretrained(*inputs, **kwargs) return cls._from_pretrained(*inputs, **kwargs)
@classmethod @classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
cache_dir = kwargs.pop('cache_dir', None) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop('force_download', False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop('resume_download', False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop('proxies', None) proxies = kwargs.pop("proxies", None)
s3_models = list(cls.max_model_input_sizes.keys()) s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {} vocab_files = {}
...@@ -317,15 +325,19 @@ class PreTrainedTokenizer(object): ...@@ -317,15 +325,19 @@ class PreTrainedTokenizer(object):
# Get the vocabulary from AWS S3 bucket # Get the vocabulary from AWS S3 bucket
for file_id, map_list in cls.pretrained_vocab_files_map.items(): for file_id, map_list in cls.pretrained_vocab_files_map.items():
vocab_files[file_id] = map_list[pretrained_model_name_or_path] vocab_files[file_id] = map_list[pretrained_model_name_or_path]
if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration: if (
cls.pretrained_init_configuration
and pretrained_model_name_or_path in cls.pretrained_init_configuration
):
init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path] init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
else: else:
# Get the vocabulary from local files # Get the vocabulary from local files
logger.info( logger.info(
"Model name '{}' not found in model shortcut name list ({}). " "Model name '{}' not found in model shortcut name list ({}). "
"Assuming '{}' is a path or url to a directory containing tokenizer files.".format( "Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
pretrained_model_name_or_path, ', '.join(s3_models), pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path
pretrained_model_name_or_path)) )
)
# Look for the tokenizer main vocabulary files # Look for the tokenizer main vocabulary files
for file_id, file_name in cls.vocab_files_names.items(): for file_id, file_name in cls.vocab_files_names.items():
...@@ -340,14 +352,15 @@ class PreTrainedTokenizer(object): ...@@ -340,14 +352,15 @@ class PreTrainedTokenizer(object):
full_file_name = pretrained_model_name_or_path full_file_name = pretrained_model_name_or_path
else: else:
full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name) full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
vocab_files[file_id] = full_file_name vocab_files[file_id] = full_file_name
# Look for the additional tokens files # Look for the additional tokens files
additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, additional_files_names = {
'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE, "added_tokens_file": ADDED_TOKENS_FILE,
'tokenizer_config_file': TOKENIZER_CONFIG_FILE, "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
} "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
}
# If a path to a file was provided, get the parent directory # If a path to a file was provided, get the parent directory
saved_directory = pretrained_model_name_or_path saved_directory = pretrained_model_name_or_path
...@@ -366,9 +379,12 @@ class PreTrainedTokenizer(object): ...@@ -366,9 +379,12 @@ class PreTrainedTokenizer(object):
"Model name '{}' was not found in tokenizers model name list ({}). " "Model name '{}' was not found in tokenizers model name list ({}). "
"We assumed '{}' was a path or url to a directory containing vocabulary files " "We assumed '{}' was a path or url to a directory containing vocabulary files "
"named {} but couldn't find such vocabulary files at this path or url.".format( "named {} but couldn't find such vocabulary files at this path or url.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path, pretrained_model_name_or_path,
list(cls.vocab_files_names.values()))) ", ".join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values()),
)
)
# Get files from url, cache, or disk depending on the case # Get files from url, cache, or disk depending on the case
try: try:
...@@ -377,17 +393,27 @@ class PreTrainedTokenizer(object): ...@@ -377,17 +393,27 @@ class PreTrainedTokenizer(object):
if file_path is None: if file_path is None:
resolved_vocab_files[file_id] = None resolved_vocab_files[file_id] = None
else: else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download) resolved_vocab_files[file_id] = cached_path(
file_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
)
except EnvironmentError: except EnvironmentError:
if pretrained_model_name_or_path in s3_models: if pretrained_model_name_or_path in s3_models:
msg = "Couldn't reach server at '{}' to download vocabulary files." msg = "Couldn't reach server at '{}' to download vocabulary files."
else: else:
msg = "Model name '{}' was not found in tokenizers model name list ({}). " \ msg = (
"We assumed '{}' was a path or url to a directory containing vocabulary files " \ "Model name '{}' was not found in tokenizers model name list ({}). "
"We assumed '{}' was a path or url to a directory containing vocabulary files "
"named {}, but couldn't find such vocabulary files at this path or url.".format( "named {}, but couldn't find such vocabulary files at this path or url.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path, pretrained_model_name_or_path,
list(cls.vocab_files_names.values())) ", ".join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values()),
)
)
raise EnvironmentError(msg) raise EnvironmentError(msg)
...@@ -395,16 +421,15 @@ class PreTrainedTokenizer(object): ...@@ -395,16 +421,15 @@ class PreTrainedTokenizer(object):
if file_path == resolved_vocab_files[file_id]: if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path)) logger.info("loading file {}".format(file_path))
else: else:
logger.info("loading file {} from cache at {}".format( logger.info("loading file {} from cache at {}".format(file_path, resolved_vocab_files[file_id]))
file_path, resolved_vocab_files[file_id]))
# Prepare tokenizer initialization kwargs # Prepare tokenizer initialization kwargs
# Did we save some inputs and kwargs to reload? # Did we save some inputs and kwargs to reload?
tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None) tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
if tokenizer_config_file is not None: if tokenizer_config_file is not None:
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
init_kwargs = json.load(tokenizer_config_handle) init_kwargs = json.load(tokenizer_config_handle)
saved_init_inputs = init_kwargs.pop('init_inputs', ()) saved_init_inputs = init_kwargs.pop("init_inputs", ())
if not init_inputs: if not init_inputs:
init_inputs = saved_init_inputs init_inputs = saved_init_inputs
else: else:
...@@ -419,11 +444,11 @@ class PreTrainedTokenizer(object): ...@@ -419,11 +444,11 @@ class PreTrainedTokenizer(object):
# won't index sequences longer than the number of positional embeddings # won't index sequences longer than the number of positional embeddings
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
if max_len is not None and isinstance(max_len, (int, float)): if max_len is not None and isinstance(max_len, (int, float)):
init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len) init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
# Merge resolved_vocab_files arguments in init_kwargs. # Merge resolved_vocab_files arguments in init_kwargs.
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None) special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
for args_name, file_path in resolved_vocab_files.items(): for args_name, file_path in resolved_vocab_files.items():
if args_name not in init_kwargs: if args_name not in init_kwargs:
init_kwargs[args_name] = file_path init_kwargs[args_name] = file_path
...@@ -438,8 +463,10 @@ class PreTrainedTokenizer(object): ...@@ -438,8 +463,10 @@ class PreTrainedTokenizer(object):
try: try:
tokenizer = cls(*init_inputs, **init_kwargs) tokenizer = cls(*init_inputs, **init_kwargs)
except OSError: except OSError:
OSError("Unable to load vocabulary from file. " OSError(
"Please check that the provided vocabulary is accessible and not corrupted.") "Unable to load vocabulary from file. "
"Please check that the provided vocabulary is accessible and not corrupted."
)
# Save inputs and kwargs for saving and re-loading with ``save_pretrained`` # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
tokenizer.init_inputs = init_inputs tokenizer.init_inputs = init_inputs
...@@ -449,13 +476,12 @@ class PreTrainedTokenizer(object): ...@@ -449,13 +476,12 @@ class PreTrainedTokenizer(object):
if added_tokens_file is not None: if added_tokens_file is not None:
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle: with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
added_tok_encoder = json.load(added_tokens_handle) added_tok_encoder = json.load(added_tokens_handle)
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()} added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
tokenizer.added_tokens_encoder.update(added_tok_encoder) tokenizer.added_tokens_encoder.update(added_tok_encoder)
tokenizer.added_tokens_decoder.update(added_tok_decoder) tokenizer.added_tokens_decoder.update(added_tok_decoder)
return tokenizer return tokenizer
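The resolution logic above accepts a shortcut name, a local directory, or a single vocabulary file; a hedged sketch of the corresponding calls (the cache path is made up):

from transformers import TransfoXLTokenizer

# download from the S3 map defined earlier in this diff and cache locally
tok = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103", cache_dir="/tmp/hf_cache")
# or load from a directory that already holds the tokenizer files
tok = TransfoXLTokenizer.from_pretrained("./transfo-xl-local")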
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" Save the tokenizer vocabulary files together with: """ Save the tokenizer vocabulary files together with:
- added tokens, - added tokens,
...@@ -476,28 +502,27 @@ class PreTrainedTokenizer(object): ...@@ -476,28 +502,27 @@ class PreTrainedTokenizer(object):
tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE) tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
tokenizer_config = copy.deepcopy(self.init_kwargs) tokenizer_config = copy.deepcopy(self.init_kwargs)
tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs) tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
for file_id in self.vocab_files_names.keys(): for file_id in self.vocab_files_names.keys():
tokenizer_config.pop(file_id, None) tokenizer_config.pop(file_id, None)
with open(tokenizer_config_file, 'w', encoding='utf-8') as f: with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False)) f.write(json.dumps(tokenizer_config, ensure_ascii=False))
with open(special_tokens_map_file, 'w', encoding='utf-8') as f: with open(special_tokens_map_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False)) f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
with open(added_tokens_file, 'w', encoding='utf-8') as f: with open(added_tokens_file, "w", encoding="utf-8") as f:
if self.added_tokens_encoder: if self.added_tokens_encoder:
out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False) out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
else: else:
out_str = u"{}" out_str = "{}"
f.write(out_str) f.write(out_str)
vocab_files = self.save_vocabulary(save_directory) vocab_files = self.save_vocabulary(save_directory)
return vocab_files + (special_tokens_map_file, added_tokens_file) return vocab_files + (special_tokens_map_file, added_tokens_file)
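A hedged round-trip sketch of save_pretrained and from_pretrained; the target directory is arbitrary, and for the Transformer-XL tokenizer the vocabulary itself is written to vocab.bin as defined earlier in this diff:

from transformers import TransfoXLTokenizer

tok = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
tok.save_pretrained("./transfo-xl-local")
# ./transfo-xl-local now holds vocab.bin, special_tokens_map.json, added_tokens.json and tokenizer_config.json
reloaded = TransfoXLTokenizer.from_pretrained("./transfo-xl-local")
assert len(reloaded) == len(tok)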
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
""" Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
and special token mappings. and special token mappings.
...@@ -506,17 +531,14 @@ class PreTrainedTokenizer(object): ...@@ -506,17 +531,14 @@ class PreTrainedTokenizer(object):
""" """
raise NotImplementedError raise NotImplementedError
def vocab_size(self): def vocab_size(self):
""" Size of the base vocabulary (without the added tokens) """ """ Size of the base vocabulary (without the added tokens) """
raise NotImplementedError raise NotImplementedError
def __len__(self): def __len__(self):
""" Size of the full vocabulary with the added tokens """ """ Size of the full vocabulary with the added tokens """
return self.vocab_size + len(self.added_tokens_encoder) return self.vocab_size + len(self.added_tokens_encoder)
def add_tokens(self, new_tokens): def add_tokens(self, new_tokens):
""" """
Add a list of new tokens to the tokenizer class. If the new tokens are not in the Add a list of new tokens to the tokenizer class. If the new tokens are not in the
...@@ -544,16 +566,18 @@ class PreTrainedTokenizer(object): ...@@ -544,16 +566,18 @@ class PreTrainedTokenizer(object):
to_add_tokens = [] to_add_tokens = []
for token in new_tokens: for token in new_tokens:
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode)) assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens: if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
token = token.lower() token = token.lower()
if token != self.unk_token and \ if (
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \ token != self.unk_token
token not in to_add_tokens: and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
and token not in to_add_tokens
):
to_add_tokens.append(token) to_add_tokens.append(token)
logger.info("Adding %s to the vocabulary", token) logger.info("Adding %s to the vocabulary", token)
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens)) added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()} added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder) self.added_tokens_encoder.update(added_tok_encoder)
self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens)) self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
self.added_tokens_decoder.update(added_tok_decoder) self.added_tokens_decoder.update(added_tok_decoder)
...@@ -622,8 +646,10 @@ class PreTrainedTokenizer(object): ...@@ -622,8 +646,10 @@ class PreTrainedTokenizer(object):
added_tokens = 0 added_tokens = 0
for key, value in special_tokens_dict.items(): for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES assert key in self.SPECIAL_TOKENS_ATTRIBUTES
if key == 'additional_special_tokens': if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value) assert isinstance(value, (list, tuple)) and all(
isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value
)
added_tokens += self.add_tokens(value) added_tokens += self.add_tokens(value)
else: else:
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode)) assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
...@@ -633,7 +659,6 @@ class PreTrainedTokenizer(object): ...@@ -633,7 +659,6 @@ class PreTrainedTokenizer(object):
return added_tokens return added_tokens
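A hedged sketch of the two entry points above; "bert-base-uncased" is just an arbitrary pretrained tokenizer to demonstrate with, and if tokens are added the matching model's token embeddings typically need to be resized afterwards:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
num_added = tokenizer.add_tokens(["new_tok1", "new_tok2"])  # returns how many entries were actually new
num_special = tokenizer.add_special_tokens({"additional_special_tokens": ["<formula>"]})
print(len(tokenizer))  # vocab_size plus the added tokens, per __len__ above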
def tokenize(self, text, **kwargs): def tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer. """ Converts a string in a sequence of tokens (string), using the tokenizer.
Split in words for word-based vocabulary or sub-words for sub-word-based Split in words for word-based vocabulary or sub-words for sub-word-based
...@@ -649,14 +674,10 @@ class PreTrainedTokenizer(object): ...@@ -649,14 +674,10 @@ class PreTrainedTokenizer(object):
def lowercase_text(t): def lowercase_text(t):
# convert non-special tokens to lowercase # convert non-special tokens to lowercase
escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens] escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \ pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
r'(.+?)' return re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), t)
return re.sub(
pattern, if self.init_kwargs.get("do_lower_case", False):
lambda m: m.groups()[0] or m.groups()[1].lower(),
t)
if self.init_kwargs.get('do_lower_case', False):
text = lowercase_text(text) text = lowercase_text(text)
def split_on_token(tok, text): def split_on_token(tok, text):
...@@ -694,9 +715,14 @@ class PreTrainedTokenizer(object): ...@@ -694,9 +715,14 @@ class PreTrainedTokenizer(object):
tokenized_text += [sub_text] tokenized_text += [sub_text]
text_list = tokenized_text text_list = tokenized_text
return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) \ return list(
if token not in self.unique_added_tokens_encoder itertools.chain.from_iterable(
else [token] for token in tokenized_text))) (
self._tokenize(token, **kwargs) if token not in self.unique_added_tokens_encoder else [token]
for token in tokenized_text
)
)
)
added_tokens = self.unique_added_tokens_encoder added_tokens = self.unique_added_tokens_encoder
tokenized_text = split_on_tokens(added_tokens, text) tokenized_text = split_on_tokens(added_tokens, text)
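To illustrate the splitting strategy above (a sketch with an arbitrary pretrained tokenizer and a made-up added token): added and special tokens are split out first and passed through untouched, so only the remaining text reaches _tokenize.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["<new_tok>"])
tokenizer.tokenize("hello <new_tok> world")
# ["hello", "<new_tok>", "world"] with this vocab; the added token is never broken into word pieces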
...@@ -737,16 +763,18 @@ class PreTrainedTokenizer(object): ...@@ -737,16 +763,18 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
raise NotImplementedError raise NotImplementedError
def encode(self, def encode(
text, self,
text_pair=None, text,
add_special_tokens=True, text_pair=None,
max_length=None, add_special_tokens=True,
stride=0, max_length=None,
truncation_strategy='longest_first', stride=0,
pad_to_max_length=False, truncation_strategy="longest_first",
return_tensors=None, pad_to_max_length=False,
**kwargs): return_tensors=None,
**kwargs
):
""" """
Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary. Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
...@@ -781,32 +809,36 @@ class PreTrainedTokenizer(object): ...@@ -781,32 +809,36 @@ class PreTrainedTokenizer(object):
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
""" """
encoded_inputs = self.encode_plus(text, encoded_inputs = self.encode_plus(
text_pair=text_pair, text,
max_length=max_length, text_pair=text_pair,
add_special_tokens=add_special_tokens, max_length=max_length,
stride=stride, add_special_tokens=add_special_tokens,
truncation_strategy=truncation_strategy, stride=stride,
pad_to_max_length=pad_to_max_length, truncation_strategy=truncation_strategy,
return_tensors=return_tensors, pad_to_max_length=pad_to_max_length,
**kwargs) return_tensors=return_tensors,
**kwargs
)
return encoded_inputs["input_ids"] return encoded_inputs["input_ids"]
def encode_plus(self, def encode_plus(
text, self,
text_pair=None, text,
add_special_tokens=True, text_pair=None,
max_length=None, add_special_tokens=True,
stride=0, max_length=None,
truncation_strategy='longest_first', stride=0,
pad_to_max_length=False, truncation_strategy="longest_first",
return_tensors=None, pad_to_max_length=False,
return_token_type_ids=True, return_tensors=None,
return_attention_mask=True, return_token_type_ids=True,
return_overflowing_tokens=False, return_attention_mask=True,
return_special_tokens_mask=False, return_overflowing_tokens=False,
**kwargs): return_special_tokens_mask=False,
**kwargs
):
""" """
Returns a dictionary containing the encoded sequence or sequence pair and additional information: Returns a dictionary containing the encoded sequence or sequence pair and additional information:
the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
...@@ -874,34 +906,40 @@ class PreTrainedTokenizer(object): ...@@ -874,34 +906,40 @@ class PreTrainedTokenizer(object):
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int): elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
return text return text
else: else:
raise ValueError("Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.") raise ValueError(
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
)
first_ids = get_input_ids(text) first_ids = get_input_ids(text)
second_ids = get_input_ids(text_pair) if text_pair is not None else None second_ids = get_input_ids(text_pair) if text_pair is not None else None
return self.prepare_for_model(first_ids, return self.prepare_for_model(
pair_ids=second_ids, first_ids,
max_length=max_length, pair_ids=second_ids,
pad_to_max_length=pad_to_max_length, max_length=max_length,
add_special_tokens=add_special_tokens, pad_to_max_length=pad_to_max_length,
stride=stride, add_special_tokens=add_special_tokens,
truncation_strategy=truncation_strategy, stride=stride,
return_tensors=return_tensors, truncation_strategy=truncation_strategy,
return_attention_mask=return_attention_mask, return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens, return_token_type_ids=return_token_type_ids,
return_special_tokens_mask=return_special_tokens_mask) return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
def batch_encode_plus(self, )
batch_text_or_text_pairs=None,
add_special_tokens=False, def batch_encode_plus(
max_length=None, self,
stride=0, batch_text_or_text_pairs=None,
truncation_strategy='longest_first', add_special_tokens=False,
return_tensors=None, max_length=None,
return_input_lengths=False, stride=0,
return_attention_masks=False, truncation_strategy="longest_first",
**kwargs): return_tensors=None,
return_input_lengths=False,
return_attention_masks=False,
**kwargs
):
""" """
Returns a dictionary containing the encoded sequence or sequence pair and additional information: Returns a dictionary containing the encoded sequence or sequence pair and additional information:
the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
...@@ -933,12 +971,19 @@ class PreTrainedTokenizer(object): ...@@ -933,12 +971,19 @@ class PreTrainedTokenizer(object):
ids, pair_ids = ids_or_pair_ids ids, pair_ids = ids_or_pair_ids
else: else:
ids, pair_ids = ids_or_pair_ids, None ids, pair_ids = ids_or_pair_ids, None
outputs = self.encode_plus(ids, pair_ids, add_special_tokens=add_special_tokens, max_length=max_length, outputs = self.encode_plus(
stride=stride, truncation_strategy=truncation_strategy, return_tensors=None) ids,
pair_ids,
add_special_tokens=add_special_tokens,
max_length=max_length,
stride=stride,
truncation_strategy=truncation_strategy,
return_tensors=None,
)
# Append the non-padded length to the output # Append the non-padded length to the output
if return_input_lengths: if return_input_lengths:
outputs['input_len'] = len(outputs['input_ids']) outputs["input_len"] = len(outputs["input_ids"])
for key, value in outputs.items(): for key, value in outputs.items():
if key not in batch_outputs: if key not in batch_outputs:
...@@ -946,11 +991,11 @@ class PreTrainedTokenizer(object): ...@@ -946,11 +991,11 @@ class PreTrainedTokenizer(object):
batch_outputs[key].append(value) batch_outputs[key].append(value)
# Compute longest sequence size # Compute longest sequence size
max_seq_len = max(map(len, batch_outputs['input_ids'])) max_seq_len = max(map(len, batch_outputs["input_ids"]))
if return_attention_masks: if return_attention_masks:
# Allow the model to not give any special attention to padded input # Allow the model to not give any special attention to padded input
batch_outputs['attention_mask'] = [[0] * len(v) for v in batch_outputs['input_ids']] batch_outputs["attention_mask"] = [[0] * len(v) for v in batch_outputs["input_ids"]]
if return_tensors is not None: if return_tensors is not None:
...@@ -958,34 +1003,48 @@ class PreTrainedTokenizer(object): ...@@ -958,34 +1003,48 @@ class PreTrainedTokenizer(object):
for key, value in batch_outputs.items(): for key, value in batch_outputs.items():
padded_value = value padded_value = value
if key != 'input_len': if key != "input_len":
# Padding handle # Padding handle
padded_value = [v + [self.pad_token_id if key == 'input_ids' else 1] * (max_seq_len - len(v)) for v in padded_value] padded_value = [
v + [self.pad_token_id if key == "input_ids" else 1] * (max_seq_len - len(v))
for v in padded_value
]
if return_tensors == 'tf' and is_tf_available(): if return_tensors == "tf" and is_tf_available():
batch_outputs[key] = tf.constant(padded_value) batch_outputs[key] = tf.constant(padded_value)
elif return_tensors == 'pt' and is_torch_available(): elif return_tensors == "pt" and is_torch_available():
batch_outputs[key] = torch.tensor(padded_value) batch_outputs[key] = torch.tensor(padded_value)
elif return_tensors is not None: elif return_tensors is not None:
logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors)) logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors
)
)
# encoder_attention_mask requires 1 for real token, 0 for padding, just invert value # encoder_attention_mask requires 1 for real token, 0 for padding, just invert value
if return_attention_masks: if return_attention_masks:
if is_tf_available(): if is_tf_available():
batch_outputs['attention_mask'] = tf.abs(batch_outputs['attention_mask'] - 1) batch_outputs["attention_mask"] = tf.abs(batch_outputs["attention_mask"] - 1)
else: else:
batch_outputs['attention_mask'] = torch.abs(batch_outputs['attention_mask'] - 1) batch_outputs["attention_mask"] = torch.abs(batch_outputs["attention_mask"] - 1)
return batch_outputs return batch_outputs
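Finally, a hedged sketch of the dict-returning variants with an arbitrary pretrained tokenizer; note that without return_tensors the batched ids stay as ragged Python lists, since the padding branch above only runs when a tensor format is requested:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

enc = tokenizer.encode_plus("Hello world", text_pair="How are you?", max_length=16, pad_to_max_length=True)
# dict with "input_ids", "token_type_ids" and "attention_mask" (the mask marks real tokens vs. padding)

batch = tokenizer.batch_encode_plus(["First sentence.", "A second, longer sentence."], add_special_tokens=True)
# batch["input_ids"] is a list with one id list per input; lengths differ because no tensor format was requested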
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0, def prepare_for_model(
truncation_strategy='longest_first', self,
pad_to_max_length=False, ids,
return_tensors=None, pair_ids=None,
return_token_type_ids=True, max_length=None,
return_attention_mask=True, add_special_tokens=True,
return_overflowing_tokens=False, stride=0,
return_special_tokens_mask=False): truncation_strategy="longest_first",
pad_to_max_length=False,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
):
""" """
Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model.
It adds special tokens, truncates It adds special tokens, truncates
...@@ -1050,10 +1109,13 @@ class PreTrainedTokenizer(object): ...@@ -1050,10 +1109,13 @@ class PreTrainedTokenizer(object):
# Handle max sequence length # Handle max sequence length
total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0) total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
if max_length and total_len > max_length: if max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids, ids, pair_ids, overflowing_tokens = self.truncate_sequences(
num_tokens_to_remove=total_len-max_length, ids,
truncation_strategy=truncation_strategy, pair_ids=pair_ids,
stride=stride) num_tokens_to_remove=total_len - max_length,
truncation_strategy=truncation_strategy,
stride=stride,
)
if return_overflowing_tokens: if return_overflowing_tokens:
encoded_inputs["overflowing_tokens"] = overflowing_tokens encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length encoded_inputs["num_truncated_tokens"] = total_len - max_length
...@@ -1081,54 +1143,64 @@ class PreTrainedTokenizer(object): ...@@ -1081,54 +1143,64 @@ class PreTrainedTokenizer(object):
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length] encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len: if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
logger.warning("Token indices sequence length is longer than the specified maximum sequence length " logger.warning(
"for this model ({} > {}). Running this sequence through the model will result in " "Token indices sequence length is longer than the specified maximum sequence length "
"indexing errors".format(len(ids), self.max_len)) "for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len)
)
needs_to_be_padded = pad_to_max_length and ( needs_to_be_padded = pad_to_max_length and (
max_length and len(encoded_inputs["input_ids"]) < max_length max_length
or and len(encoded_inputs["input_ids"]) < max_length
max_length is None and len(encoded_inputs["input_ids"]) < self.max_len and self.max_len <= 10000 or max_length is None
and len(encoded_inputs["input_ids"]) < self.max_len
and self.max_len <= 10000
) )
if pad_to_max_length and max_length is None and self.max_len > 10000: if pad_to_max_length and max_length is None and self.max_len > 10000:
logger.warning("Sequence can't be padded as no maximum length is specified and the model maximum length is too high.") logger.warning(
"Sequence can't be padded as no maximum length is specified and the model maximum length is too high."
)
if needs_to_be_padded: if needs_to_be_padded:
difference = (max_length if max_length is not None else self.max_len) - len(encoded_inputs["input_ids"]) difference = (max_length if max_length is not None else self.max_len) - len(encoded_inputs["input_ids"])
if self.padding_side == 'right': if self.padding_side == "right":
if return_attention_mask: if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
if return_token_type_ids: if return_token_type_ids:
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
)
if return_special_tokens_mask: if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
elif self.padding_side == 'left': elif self.padding_side == "left":
if return_attention_mask: if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"]) encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
if return_token_type_ids: if return_token_type_ids:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs["token_type_ids"] encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"
]
if return_special_tokens_mask: if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"] encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
else: else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side)) raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask: elif return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
# Prepare inputs as tensors if asked # Prepare inputs as tensors if asked
if return_tensors == 'tf' and is_tf_available(): if return_tensors == "tf" and is_tf_available():
encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]]) encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]])
encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]]) encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]])
if "attention_mask" in encoded_inputs: if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = tf.constant([encoded_inputs["attention_mask"]]) encoded_inputs["attention_mask"] = tf.constant([encoded_inputs["attention_mask"]])
elif return_tensors == 'pt' and is_torch_available(): elif return_tensors == "pt" and is_torch_available():
encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]]) encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]])
encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]]) encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]])
...@@ -1137,11 +1209,15 @@ class PreTrainedTokenizer(object): ...@@ -1137,11 +1209,15 @@ class PreTrainedTokenizer(object):
elif return_tensors is not None: elif return_tensors is not None:
logger.warning( logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors)) return_tensors
)
)
return encoded_inputs return encoded_inputs
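Illustrative sketch (not part of this commit): the left/right padding branch above can be reduced to a small standalone helper; `pad_sequence` and its defaults are hypothetical, not the library API.

# Minimal sketch of the padding logic shown above: pad to max_length on the chosen side
# and build the matching attention mask (1 = real token, 0 = padding).
def pad_sequence(input_ids, max_length, pad_token_id=0, padding_side="right"):
    difference = max_length - len(input_ids)
    if difference <= 0:
        return input_ids, [1] * len(input_ids)
    if padding_side == "right":
        attention_mask = [1] * len(input_ids) + [0] * difference
        input_ids = input_ids + [pad_token_id] * difference
    elif padding_side == "left":
        attention_mask = [0] * difference + [1] * len(input_ids)
        input_ids = [pad_token_id] * difference + input_ids
    else:
        raise ValueError("Invalid padding strategy: " + str(padding_side))
    return input_ids, attention_mask

print(pad_sequence([5, 6, 7], 5))                       # ([5, 6, 7, 0, 0], [1, 1, 1, 0, 0])
print(pad_sequence([5, 6, 7], 5, padding_side="left"))  # ([0, 0, 5, 6, 7], [0, 0, 1, 1, 1])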
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0): def truncate_sequences(
self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy="longest_first", stride=0
):
"""Truncates a sequence pair in place to the maximum length. """Truncates a sequence pair in place to the maximum length.
truncation_strategy: string selected in the following options: truncation_strategy: string selected in the following options:
- 'longest_first' (default) Iteratively reduce the input sequences until the input is under max_length - 'longest_first' (default) Iteratively reduce the input sequences until the input is under max_length
...@@ -1154,7 +1230,7 @@ class PreTrainedTokenizer(object): ...@@ -1154,7 +1230,7 @@ class PreTrainedTokenizer(object):
if num_tokens_to_remove <= 0: if num_tokens_to_remove <= 0:
return ids, pair_ids, [] return ids, pair_ids, []
if truncation_strategy == 'longest_first': if truncation_strategy == "longest_first":
overflowing_tokens = [] overflowing_tokens = []
for _ in range(num_tokens_to_remove): for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids): if pair_ids is None or len(ids) > len(pair_ids):
...@@ -1165,20 +1241,22 @@ class PreTrainedTokenizer(object): ...@@ -1165,20 +1241,22 @@ class PreTrainedTokenizer(object):
window_len = min(len(ids), stride) window_len = min(len(ids), stride)
if window_len > 0: if window_len > 0:
overflowing_tokens = ids[-window_len:] + overflowing_tokens overflowing_tokens = ids[-window_len:] + overflowing_tokens
elif truncation_strategy == 'only_first': elif truncation_strategy == "only_first":
assert len(ids) > num_tokens_to_remove assert len(ids) > num_tokens_to_remove
window_len = min(len(ids), stride + num_tokens_to_remove) window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:] overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove] ids = ids[:-num_tokens_to_remove]
elif truncation_strategy == 'only_second': elif truncation_strategy == "only_second":
assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
window_len = min(len(pair_ids), stride + num_tokens_to_remove) window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:] overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove] pair_ids = pair_ids[:-num_tokens_to_remove]
elif truncation_strategy == 'do_not_truncate': elif truncation_strategy == "do_not_truncate":
raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.") raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")
else: else:
raise ValueError("Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']") raise ValueError(
"Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']"
)
return (ids, pair_ids, overflowing_tokens) return (ids, pair_ids, overflowing_tokens)
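Illustrative sketch (not part of this commit): a minimal, self-contained version of the 'longest_first' strategy above, which repeatedly trims whichever sequence is currently longer and keeps a stride-sized window of the removed tokens as overflow; `truncate_longest_first` is a hypothetical name.

def truncate_longest_first(ids, pair_ids=None, num_tokens_to_remove=0, stride=0):
    # Drop one token at a time from the longer of the two sequences.
    ids = list(ids)
    pair_ids = list(pair_ids) if pair_ids is not None else None
    overflowing = []
    for _ in range(num_tokens_to_remove):
        if pair_ids is None or len(ids) > len(pair_ids):
            overflowing = [ids[-1]] + overflowing
            ids = ids[:-1]
        else:
            pair_ids = pair_ids[:-1]
    # Keep an extra stride-sized window from the first sequence for overlap.
    window_len = min(len(ids), stride)
    if window_len > 0:
        overflowing = ids[-window_len:] + overflowing
    return ids, pair_ids, overflowing

print(truncate_longest_first([1, 2, 3, 4, 5, 6], [7, 8], num_tokens_to_remove=3, stride=1))
# ([1, 2, 3], [7, 8], [3, 4, 5, 6])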
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
...@@ -1246,7 +1324,7 @@ class PreTrainedTokenizer(object): ...@@ -1246,7 +1324,7 @@ class PreTrainedTokenizer(object):
The simplest way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)) The simplest way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
but we often want to remove sub-word tokenization artifacts at the same time. but we often want to remove sub-word tokenization artifacts at the same time.
""" """
return ' '.join(self.convert_ids_to_tokens(tokens)) return " ".join(self.convert_ids_to_tokens(tokens))
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
""" """
...@@ -1278,7 +1356,7 @@ class PreTrainedTokenizer(object): ...@@ -1278,7 +1356,7 @@ class PreTrainedTokenizer(object):
current_sub_text.append(token) current_sub_text.append(token)
if current_sub_text: if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text)) sub_texts.append(self.convert_tokens_to_string(current_sub_text))
text = ' '.join(sub_texts) text = " ".join(sub_texts)
if clean_up_tokenization_spaces: if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text) clean_text = self.clean_up_tokenization(text)
...@@ -1323,7 +1401,17 @@ class PreTrainedTokenizer(object): ...@@ -1323,7 +1401,17 @@ class PreTrainedTokenizer(object):
def clean_up_tokenization(out_string): def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms. """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
""" """
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' out_string = (
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" out_string.replace(" .", ".")
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") .replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" do not", " don't")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string return out_string
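Illustrative example (not part of this commit): the chained replacements above undo common detokenization artifacts; a trimmed-down version of the same idea, for a quick sanity check (hypothetical `clean_up` helper):

def clean_up(text):
    # Re-attach punctuation and contractions that naive token joining left detached.
    for before, after in [(" .", "."), (" ?", "?"), (" !", "!"), (" ,", ","), (" n't", "n't"), (" 's", "'s")]:
        text = text.replace(before, after)
    return text

print(clean_up("do n't hesitate , it 's fine !"))  # "don't hesitate, it's fine!"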
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for XLM.""" """Tokenization classes for XLM."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import json import json
import logging import logging
...@@ -32,386 +31,402 @@ from .tokenization_bert import BasicTokenizer ...@@ -32,386 +31,402 @@ from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', "vocab_file": "vocab.json",
'merges_file': 'merges.txt', "merges_file": "merges.txt",
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json", "xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json", "xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json", "xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json", "xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json", "xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json", "xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json", "xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json", "xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json", "xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-vocab.json",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-vocab.json",
}, },
'merges_file': "merges_file": {
{ "xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt", "xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt", "xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt", "xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt", "xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt", "xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt", "xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt", "xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt", "xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt", "xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-merges.txt",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-merges.txt",
}, },
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-mlm-en-2048': 512, "xlm-mlm-en-2048": 512,
'xlm-mlm-ende-1024': 512, "xlm-mlm-ende-1024": 512,
'xlm-mlm-enfr-1024': 512, "xlm-mlm-enfr-1024": 512,
'xlm-mlm-enro-1024': 512, "xlm-mlm-enro-1024": 512,
'xlm-mlm-tlm-xnli15-1024': 512, "xlm-mlm-tlm-xnli15-1024": 512,
'xlm-mlm-xnli15-1024': 512, "xlm-mlm-xnli15-1024": 512,
'xlm-clm-enfr-1024': 512, "xlm-clm-enfr-1024": 512,
'xlm-clm-ende-1024': 512, "xlm-clm-ende-1024": 512,
'xlm-mlm-17-1280': 512, "xlm-mlm-17-1280": 512,
'xlm-mlm-100-1280': 512, "xlm-mlm-100-1280": 512,
} }
PRETRAINED_INIT_CONFIGURATION = { PRETRAINED_INIT_CONFIGURATION = {
'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True}, "xlm-mlm-en-2048": {"do_lowercase_and_remove_accent": True},
'xlm-mlm-ende-1024': { "do_lowercase_and_remove_accent": True, "xlm-mlm-ende-1024": {
"id2lang": { "0": "de", "do_lowercase_and_remove_accent": True,
"1": "en"}, "id2lang": {"0": "de", "1": "en"},
"lang2id": { "de": 0, "lang2id": {"de": 0, "en": 1},
"en": 1 }}, },
'xlm-mlm-enfr-1024': { "do_lowercase_and_remove_accent": True, "xlm-mlm-enfr-1024": {
"id2lang": { "0": "en", "do_lowercase_and_remove_accent": True,
"1": "fr"}, "id2lang": {"0": "en", "1": "fr"},
"lang2id": { "en": 0, "lang2id": {"en": 0, "fr": 1},
"fr": 1 }}, },
'xlm-mlm-enro-1024': { "do_lowercase_and_remove_accent": True, "xlm-mlm-enro-1024": {
"id2lang": { "0": "en", "do_lowercase_and_remove_accent": True,
"1": "ro"}, "id2lang": {"0": "en", "1": "ro"},
"lang2id": { "en": 0, "lang2id": {"en": 0, "ro": 1},
"ro": 1 }}, },
'xlm-mlm-tlm-xnli15-1024': { "do_lowercase_and_remove_accent": True, "xlm-mlm-tlm-xnli15-1024": {
"id2lang": { "0": "ar", "do_lowercase_and_remove_accent": True,
"1": "bg", "id2lang": {
"2": "de", "0": "ar",
"3": "el", "1": "bg",
"4": "en", "2": "de",
"5": "es", "3": "el",
"6": "fr", "4": "en",
"7": "hi", "5": "es",
"8": "ru", "6": "fr",
"9": "sw", "7": "hi",
"10": "th", "8": "ru",
"11": "tr", "9": "sw",
"12": "ur", "10": "th",
"13": "vi", "11": "tr",
"14": "zh"}, "12": "ur",
"lang2id": { "ar": 0, "13": "vi",
"bg": 1, "14": "zh",
"de": 2, },
"el": 3, "lang2id": {
"en": 4, "ar": 0,
"es": 5, "bg": 1,
"fr": 6, "de": 2,
"hi": 7, "el": 3,
"ru": 8, "en": 4,
"sw": 9, "es": 5,
"th": 10, "fr": 6,
"tr": 11, "hi": 7,
"ur": 12, "ru": 8,
"vi": 13, "sw": 9,
"zh": 14 }}, "th": 10,
'xlm-mlm-xnli15-1024': { "do_lowercase_and_remove_accent": True, "tr": 11,
"id2lang": { "0": "ar", "ur": 12,
"1": "bg", "vi": 13,
"2": "de", "zh": 14,
"3": "el", },
"4": "en", },
"5": "es", "xlm-mlm-xnli15-1024": {
"6": "fr", "do_lowercase_and_remove_accent": True,
"7": "hi", "id2lang": {
"8": "ru", "0": "ar",
"9": "sw", "1": "bg",
"10": "th", "2": "de",
"11": "tr", "3": "el",
"12": "ur", "4": "en",
"13": "vi", "5": "es",
"14": "zh"}, "6": "fr",
"lang2id": { "ar": 0, "7": "hi",
"bg": 1, "8": "ru",
"de": 2, "9": "sw",
"el": 3, "10": "th",
"en": 4, "11": "tr",
"es": 5, "12": "ur",
"fr": 6, "13": "vi",
"hi": 7, "14": "zh",
"ru": 8, },
"sw": 9, "lang2id": {
"th": 10, "ar": 0,
"tr": 11, "bg": 1,
"ur": 12, "de": 2,
"vi": 13, "el": 3,
"zh": 14 }}, "en": 4,
'xlm-clm-enfr-1024': { "do_lowercase_and_remove_accent": True, "es": 5,
"id2lang": { "0": "en", "fr": 6,
"1": "fr"}, "hi": 7,
"lang2id": { "en": 0, "ru": 8,
"fr": 1 }}, "sw": 9,
'xlm-clm-ende-1024': { "do_lowercase_and_remove_accent": True, "th": 10,
"id2lang": { "0": "de", "tr": 11,
"1": "en"}, "ur": 12,
"lang2id": { "de": 0, "vi": 13,
"en": 1 }}, "zh": 14,
'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False, },
"id2lang": { },
"0": "ar", "xlm-clm-enfr-1024": {
"1": "de", "do_lowercase_and_remove_accent": True,
"2": "en", "id2lang": {"0": "en", "1": "fr"},
"3": "es", "lang2id": {"en": 0, "fr": 1},
"4": "fr", },
"5": "hi", "xlm-clm-ende-1024": {
"6": "it", "do_lowercase_and_remove_accent": True,
"7": "ja", "id2lang": {"0": "de", "1": "en"},
"8": "ko", "lang2id": {"de": 0, "en": 1},
"9": "nl", },
"10": "pl", "xlm-mlm-17-1280": {
"11": "pt", "do_lowercase_and_remove_accent": False,
"12": "ru", "id2lang": {
"13": "sv", "0": "ar",
"14": "tr", "1": "de",
"15": "vi", "2": "en",
"16": "zh" "3": "es",
}, "4": "fr",
"lang2id": { "5": "hi",
"ar": 0, "6": "it",
"de": 1, "7": "ja",
"en": 2, "8": "ko",
"es": 3, "9": "nl",
"fr": 4, "10": "pl",
"hi": 5, "11": "pt",
"it": 6, "12": "ru",
"ja": 7, "13": "sv",
"ko": 8, "14": "tr",
"nl": 9, "15": "vi",
"pl": 10, "16": "zh",
"pt": 11, },
"ru": 12, "lang2id": {
"sv": 13, "ar": 0,
"tr": 14, "de": 1,
"vi": 15, "en": 2,
"zh": 16}}, "es": 3,
'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False, "fr": 4,
"id2lang": { "hi": 5,
"0": "af", "it": 6,
"1": "als", "ja": 7,
"2": "am", "ko": 8,
"3": "an", "nl": 9,
"4": "ang", "pl": 10,
"5": "ar", "pt": 11,
"6": "arz", "ru": 12,
"7": "ast", "sv": 13,
"8": "az", "tr": 14,
"9": "bar", "vi": 15,
"10": "be", "zh": 16,
"11": "bg", },
"12": "bn", },
"13": "br", "xlm-mlm-100-1280": {
"14": "bs", "do_lowercase_and_remove_accent": False,
"15": "ca", "id2lang": {
"16": "ceb", "0": "af",
"17": "ckb", "1": "als",
"18": "cs", "2": "am",
"19": "cy", "3": "an",
"20": "da", "4": "ang",
"21": "de", "5": "ar",
"22": "el", "6": "arz",
"23": "en", "7": "ast",
"24": "eo", "8": "az",
"25": "es", "9": "bar",
"26": "et", "10": "be",
"27": "eu", "11": "bg",
"28": "fa", "12": "bn",
"29": "fi", "13": "br",
"30": "fr", "14": "bs",
"31": "fy", "15": "ca",
"32": "ga", "16": "ceb",
"33": "gan", "17": "ckb",
"34": "gl", "18": "cs",
"35": "gu", "19": "cy",
"36": "he", "20": "da",
"37": "hi", "21": "de",
"38": "hr", "22": "el",
"39": "hu", "23": "en",
"40": "hy", "24": "eo",
"41": "ia", "25": "es",
"42": "id", "26": "et",
"43": "is", "27": "eu",
"44": "it", "28": "fa",
"45": "ja", "29": "fi",
"46": "jv", "30": "fr",
"47": "ka", "31": "fy",
"48": "kk", "32": "ga",
"49": "kn", "33": "gan",
"50": "ko", "34": "gl",
"51": "ku", "35": "gu",
"52": "la", "36": "he",
"53": "lb", "37": "hi",
"54": "lt", "38": "hr",
"55": "lv", "39": "hu",
"56": "mk", "40": "hy",
"57": "ml", "41": "ia",
"58": "mn", "42": "id",
"59": "mr", "43": "is",
"60": "ms", "44": "it",
"61": "my", "45": "ja",
"62": "nds", "46": "jv",
"63": "ne", "47": "ka",
"64": "nl", "48": "kk",
"65": "nn", "49": "kn",
"66": "no", "50": "ko",
"67": "oc", "51": "ku",
"68": "pl", "52": "la",
"69": "pt", "53": "lb",
"70": "ro", "54": "lt",
"71": "ru", "55": "lv",
"72": "scn", "56": "mk",
"73": "sco", "57": "ml",
"74": "sh", "58": "mn",
"75": "si", "59": "mr",
"76": "simple", "60": "ms",
"77": "sk", "61": "my",
"78": "sl", "62": "nds",
"79": "sq", "63": "ne",
"80": "sr", "64": "nl",
"81": "sv", "65": "nn",
"82": "sw", "66": "no",
"83": "ta", "67": "oc",
"84": "te", "68": "pl",
"85": "th", "69": "pt",
"86": "tl", "70": "ro",
"87": "tr", "71": "ru",
"88": "tt", "72": "scn",
"89": "uk", "73": "sco",
"90": "ur", "74": "sh",
"91": "uz", "75": "si",
"92": "vi", "76": "simple",
"93": "war", "77": "sk",
"94": "wuu", "78": "sl",
"95": "yi", "79": "sq",
"96": "zh", "80": "sr",
"97": "zh_classical", "81": "sv",
"98": "zh_min_nan", "82": "sw",
"99": "zh_yue" "83": "ta",
}, "84": "te",
"lang2id": { "85": "th",
"af": 0, "86": "tl",
"als": 1, "87": "tr",
"am": 2, "88": "tt",
"an": 3, "89": "uk",
"ang": 4, "90": "ur",
"ar": 5, "91": "uz",
"arz": 6, "92": "vi",
"ast": 7, "93": "war",
"az": 8, "94": "wuu",
"bar": 9, "95": "yi",
"be": 10, "96": "zh",
"bg": 11, "97": "zh_classical",
"bn": 12, "98": "zh_min_nan",
"br": 13, "99": "zh_yue",
"bs": 14, },
"ca": 15, "lang2id": {
"ceb": 16, "af": 0,
"ckb": 17, "als": 1,
"cs": 18, "am": 2,
"cy": 19, "an": 3,
"da": 20, "ang": 4,
"de": 21, "ar": 5,
"el": 22, "arz": 6,
"en": 23, "ast": 7,
"eo": 24, "az": 8,
"es": 25, "bar": 9,
"et": 26, "be": 10,
"eu": 27, "bg": 11,
"fa": 28, "bn": 12,
"fi": 29, "br": 13,
"fr": 30, "bs": 14,
"fy": 31, "ca": 15,
"ga": 32, "ceb": 16,
"gan": 33, "ckb": 17,
"gl": 34, "cs": 18,
"gu": 35, "cy": 19,
"he": 36, "da": 20,
"hi": 37, "de": 21,
"hr": 38, "el": 22,
"hu": 39, "en": 23,
"hy": 40, "eo": 24,
"ia": 41, "es": 25,
"id": 42, "et": 26,
"is": 43, "eu": 27,
"it": 44, "fa": 28,
"ja": 45, "fi": 29,
"jv": 46, "fr": 30,
"ka": 47, "fy": 31,
"kk": 48, "ga": 32,
"kn": 49, "gan": 33,
"ko": 50, "gl": 34,
"ku": 51, "gu": 35,
"la": 52, "he": 36,
"lb": 53, "hi": 37,
"lt": 54, "hr": 38,
"lv": 55, "hu": 39,
"mk": 56, "hy": 40,
"ml": 57, "ia": 41,
"mn": 58, "id": 42,
"mr": 59, "is": 43,
"ms": 60, "it": 44,
"my": 61, "ja": 45,
"nds": 62, "jv": 46,
"ne": 63, "ka": 47,
"nl": 64, "kk": 48,
"nn": 65, "kn": 49,
"no": 66, "ko": 50,
"oc": 67, "ku": 51,
"pl": 68, "la": 52,
"pt": 69, "lb": 53,
"ro": 70, "lt": 54,
"ru": 71, "lv": 55,
"scn": 72, "mk": 56,
"sco": 73, "ml": 57,
"sh": 74, "mn": 58,
"si": 75, "mr": 59,
"simple": 76, "ms": 60,
"sk": 77, "my": 61,
"sl": 78, "nds": 62,
"sq": 79, "ne": 63,
"sr": 80, "nl": 64,
"sv": 81, "nn": 65,
"sw": 82, "no": 66,
"ta": 83, "oc": 67,
"te": 84, "pl": 68,
"th": 85, "pt": 69,
"tl": 86, "ro": 70,
"tr": 87, "ru": 71,
"tt": 88, "scn": 72,
"uk": 89, "sco": 73,
"ur": 90, "sh": 74,
"uz": 91, "si": 75,
"vi": 92, "simple": 76,
"war": 93, "sk": 77,
"wuu": 94, "sl": 78,
"yi": 95, "sq": 79,
"zh": 96, "sr": 80,
"zh_classical": 97, "sv": 81,
"zh_min_nan": 98, "sw": 82,
"zh_yue": 99 "ta": 83,
}}, "te": 84,
"th": 85,
"tl": 86,
"tr": 87,
"tt": 88,
"uk": 89,
"ur": 90,
"uz": 91,
"vi": 92,
"war": 93,
"wuu": 94,
"yi": 95,
"zh": 96,
"zh_classical": 97,
"zh_min_nan": 98,
"zh_yue": 99,
},
},
} }
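Illustrative sketch (not part of this commit): each configuration above pairs an `id2lang` table (string ids to language codes) with its inverse `lang2id`; a hypothetical consistency check makes that relationship explicit.

def check_lang_maps(config):
    # id2lang maps string-typed ids to language codes; lang2id must be its exact inverse.
    id2lang, lang2id = config["id2lang"], config["lang2id"]
    assert len(id2lang) == len(lang2id)
    for str_id, lang in id2lang.items():
        assert lang2id[lang] == int(str_id)

check_lang_maps({"do_lowercase_and_remove_accent": True, "id2lang": {"0": "de", "1": "en"}, "lang2id": {"de": 0, "en": 1}})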
def get_pairs(word): def get_pairs(word):
""" """
Return set of symbol pairs in a word. Return set of symbol pairs in a word.
...@@ -430,7 +445,7 @@ def lowercase_and_remove_accent(text): ...@@ -430,7 +445,7 @@ def lowercase_and_remove_accent(text):
Lowercase and strips accents from a piece of text based on Lowercase and strips accents from a piece of text based on
https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
""" """
text = ' '.join(text) text = " ".join(text)
text = text.lower() text = text.lower()
text = unicodedata.normalize("NFD", text) text = unicodedata.normalize("NFD", text)
output = [] output = []
...@@ -439,73 +454,73 @@ def lowercase_and_remove_accent(text): ...@@ -439,73 +454,73 @@ def lowercase_and_remove_accent(text):
if cat == "Mn": if cat == "Mn":
continue continue
output.append(char) output.append(char)
return "".join(output).lower().split(' ') return "".join(output).lower().split(" ")
def replace_unicode_punct(text): def replace_unicode_punct(text):
''' """
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
''' """
text = text.replace(',', ',') text = text.replace(",", ",")
text = re.sub(r'。\s*', '. ', text) text = re.sub(r"。\s*", ". ", text)
text = text.replace('、', ',') text = text.replace("、", ",")
text = text.replace('”', '"') text = text.replace("”", '"')
text = text.replace('“', '"') text = text.replace("“", '"')
text = text.replace('∶', ':') text = text.replace("∶", ":")
text = text.replace(':', ':') text = text.replace(":", ":")
text = text.replace('?', '?') text = text.replace("?", "?")
text = text.replace('《', '"') text = text.replace("《", '"')
text = text.replace('》', '"') text = text.replace("》", '"')
text = text.replace(')', ')') text = text.replace(")", ")")
text = text.replace('!', '!') text = text.replace("!", "!")
text = text.replace('(', '(') text = text.replace("(", "(")
text = text.replace(';', ';') text = text.replace(";", ";")
text = text.replace('1', '"') text = text.replace("1", '"')
text = text.replace('」', '"') text = text.replace("」", '"')
text = text.replace('「', '"') text = text.replace("「", '"')
text = text.replace('0', '0') text = text.replace("0", "0")
text = text.replace('3', '3') text = text.replace("3", "3")
text = text.replace('2', '2') text = text.replace("2", "2")
text = text.replace('5', '5') text = text.replace("5", "5")
text = text.replace('6', '6') text = text.replace("6", "6")
text = text.replace('9', '9') text = text.replace("9", "9")
text = text.replace('7', '7') text = text.replace("7", "7")
text = text.replace('8', '8') text = text.replace("8", "8")
text = text.replace('4', '4') text = text.replace("4", "4")
text = re.sub(r'.\s*', '. ', text) text = re.sub(r".\s*", ". ", text)
text = text.replace('~', '~') text = text.replace("~", "~")
text = text.replace('’', '\'') text = text.replace("’", "'")
text = text.replace('…', '...') text = text.replace("…", "...")
text = text.replace('━', '-') text = text.replace("━", "-")
text = text.replace('〈', '<') text = text.replace("〈", "<")
text = text.replace('〉', '>') text = text.replace("〉", ">")
text = text.replace('【', '[') text = text.replace("【", "[")
text = text.replace('】', ']') text = text.replace("】", "]")
text = text.replace('%', '%') text = text.replace("%", "%")
return text return text
def remove_non_printing_char(text): def remove_non_printing_char(text):
''' """
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
''' """
output = [] output = []
for char in text: for char in text:
cat = unicodedata.category(char) cat = unicodedata.category(char)
if cat.startswith('C'): if cat.startswith("C"):
continue continue
output.append(char) output.append(char)
return "".join(output) return "".join(output)
def romanian_preprocessing(text): def romanian_preprocessing(text):
'''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`''' """Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`"""
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219") text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b") text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
text = text.replace("\u0218", "S").replace("\u0219", "s") #s-comma text = text.replace("\u0218", "S").replace("\u0219", "s") # s-comma
text = text.replace("\u021a", "T").replace("\u021b", "t") #t-comma text = text.replace("\u021a", "T").replace("\u021b", "t") # t-comma
text = text.replace("\u0102", "A").replace("\u0103", "a") text = text.replace("\u0102", "A").replace("\u0103", "a")
text = text.replace("\u00C2", "A").replace("\u00E2", "a") text = text.replace("\u00C2", "A").replace("\u00E2", "a")
text = text.replace("\u00CE", "I").replace("\u00EE", "i") text = text.replace("\u00CE", "I").replace("\u00EE", "i")
...@@ -531,33 +546,58 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -531,33 +546,58 @@ class XLMTokenizer(PreTrainedTokenizer):
- `do_lowercase_and_remove_accent` controls lowercasing and accent removal (automatically set for pretrained vocabularies) - `do_lowercase_and_remove_accent` controls lowercasing and accent removal (automatically set for pretrained vocabularies)
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>", def __init__(
sep_token="</s>", pad_token="<pad>", cls_token="</s>", self,
mask_token="<special1>", additional_special_tokens=["<special0>", vocab_file,
"<special1>", "<special2>", "<special3>", "<special4>", "<special5>", merges_file,
"<special6>", "<special7>", "<special8>", "<special9>"], unk_token="<unk>",
lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True, bos_token="<s>",
**kwargs): sep_token="</s>",
super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token, pad_token="<pad>",
sep_token=sep_token, pad_token=pad_token, cls_token="</s>",
cls_token=cls_token, mask_token=mask_token, mask_token="<special1>",
additional_special_tokens=additional_special_tokens, additional_special_tokens=[
**kwargs) "<special0>",
"<special1>",
"<special2>",
"<special3>",
"<special4>",
"<special5>",
"<special6>",
"<special7>",
"<special8>",
"<special9>",
],
lang2id=None,
id2lang=None,
do_lowercase_and_remove_accent=True,
**kwargs
):
super(XLMTokenizer, self).__init__(
unk_token=unk_token,
bos_token=bos_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
# cache of sm.MosesPunctNormalizer instance # cache of sm.MosesPunctNormalizer instance
self.cache_moses_punct_normalizer = dict() self.cache_moses_punct_normalizer = dict()
# cache of sm.MosesTokenizer instance # cache of sm.MosesTokenizer instance
self.cache_moses_tokenizer = dict() self.cache_moses_tokenizer = dict()
self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja']) self.lang_with_custom_tokenizer = set(["zh", "th", "ja"])
# True for current supported model (v1.2.0), False for XLM-17 & 100 # True for current supported model (v1.2.0), False for XLM-17 & 100
self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
self.lang2id = lang2id self.lang2id = lang2id
...@@ -570,9 +610,9 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -570,9 +610,9 @@ class XLMTokenizer(PreTrainedTokenizer):
with open(vocab_file, encoding="utf-8") as vocab_handle: with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle) self.encoder = json.load(vocab_handle)
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v: k for k, v in self.encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle: with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split('\n')[:-1] merges = merges_handle.read().split("\n")[:-1]
merges = [tuple(merge.split()[:2]) for merge in merges] merges = [tuple(merge.split()[:2]) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {} self.cache = {}
...@@ -603,9 +643,14 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -603,9 +643,14 @@ class XLMTokenizer(PreTrainedTokenizer):
if self.ja_word_tokenizer is None: if self.ja_word_tokenizer is None:
try: try:
import Mykytea import Mykytea
self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
self.ja_word_tokenizer = Mykytea.Mykytea(
"-model %s/local/share/kytea/model.bin" % os.path.expanduser("~")
)
except (AttributeError, ImportError) as e: except (AttributeError, ImportError) as e:
logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps") logger.error(
"Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
)
logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea") logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
logger.error("2. autoreconf -i") logger.error("2. autoreconf -i")
logger.error("3. ./configure --prefix=$HOME/local") logger.error("3. ./configure --prefix=$HOME/local")
...@@ -619,16 +664,16 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -619,16 +664,16 @@ class XLMTokenizer(PreTrainedTokenizer):
return len(self.encoder) return len(self.encoder)
def bpe(self, token): def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',) word = tuple(token[:-1]) + (token[-1] + "</w>",)
if token in self.cache: if token in self.cache:
return self.cache[token] return self.cache[token]
pairs = get_pairs(word) pairs = get_pairs(word)
if not pairs: if not pairs:
return token+'</w>' return token + "</w>"
while True: while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
...@@ -643,8 +688,8 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -643,8 +688,8 @@ class XLMTokenizer(PreTrainedTokenizer):
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first+second) new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
...@@ -655,13 +700,13 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -655,13 +700,13 @@ class XLMTokenizer(PreTrainedTokenizer):
break break
else: else:
pairs = get_pairs(word) pairs = get_pairs(word)
word = ' '.join(word) word = " ".join(word)
if word == '\n </w>': if word == "\n </w>":
word = '\n</w>' word = "\n</w>"
self.cache[token] = word self.cache[token] = word
return word return word
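Illustrative sketch (not part of this commit): the loop above repeatedly merges the adjacent symbol pair with the lowest merge rank until no ranked pair remains; a toy, self-contained version (hypothetical `toy_bpe`, toy ranks):

def adjacent_pairs(word):
    # Set of adjacent symbol pairs in the current segmentation.
    return {(word[i], word[i + 1]) for i in range(len(word) - 1)}

def toy_bpe(token, bpe_ranks):
    word = tuple(token[:-1]) + (token[-1] + "</w>",)
    pairs = adjacent_pairs(word)
    while pairs:
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        pairs = adjacent_pairs(word)
    return " ".join(word)

print(toy_bpe("low", {("l", "o"): 0, ("lo", "w</w>"): 1}))  # "low</w>"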
def _tokenize(self, text, lang='en', bypass_tokenizer=False): def _tokenize(self, text, lang="en", bypass_tokenizer=False):
""" """
Tokenize a string given a language code. For Chinese, Japanese and Thai, we use a language-specific tokenizer. Otherwise, we use Moses. Tokenize a string given a language code. For Chinese, Japanese and Thai, we use a language-specific tokenizer. Otherwise, we use Moses.
...@@ -697,45 +742,49 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -697,45 +742,49 @@ class XLMTokenizer(PreTrainedTokenizer):
List of tokens. List of tokens.
""" """
if lang and self.lang2id and lang not in self.lang2id: if lang and self.lang2id and lang not in self.lang2id:
logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.") logger.error(
"Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model."
)
if bypass_tokenizer: if bypass_tokenizer:
text = text.split() text = text.split()
elif lang not in self.lang_with_custom_tokenizer: elif lang not in self.lang_with_custom_tokenizer:
text = self.moses_pipeline(text, lang=lang) text = self.moses_pipeline(text, lang=lang)
# TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
if lang == 'ro': if lang == "ro":
text = romanian_preprocessing(text) text = romanian_preprocessing(text)
text = self.moses_tokenize(text, lang=lang) text = self.moses_tokenize(text, lang=lang)
elif lang == 'th': elif lang == "th":
text = self.moses_pipeline(text, lang=lang) text = self.moses_pipeline(text, lang=lang)
try: try:
if 'pythainlp' not in sys.modules: if "pythainlp" not in sys.modules:
from pythainlp.tokenize import word_tokenize as th_word_tokenize from pythainlp.tokenize import word_tokenize as th_word_tokenize
else: else:
th_word_tokenize = sys.modules['pythainlp'].word_tokenize th_word_tokenize = sys.modules["pythainlp"].word_tokenize
except (AttributeError, ImportError) as e: except (AttributeError, ImportError) as e:
logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps") logger.error(
"Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps"
)
logger.error("1. pip install pythainlp") logger.error("1. pip install pythainlp")
raise e raise e
text = th_word_tokenize(text) text = th_word_tokenize(text)
elif lang == 'zh': elif lang == "zh":
try: try:
if 'jieba' not in sys.modules: if "jieba" not in sys.modules:
import jieba import jieba
else: else:
jieba = sys.modules['jieba'] jieba = sys.modules["jieba"]
except (AttributeError, ImportError) as e: except (AttributeError, ImportError) as e:
logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps") logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
logger.error("1. pip install jieba") logger.error("1. pip install jieba")
raise e raise e
text = ' '.join(jieba.cut(text)) text = " ".join(jieba.cut(text))
text = self.moses_pipeline(text, lang=lang) text = self.moses_pipeline(text, lang=lang)
text = text.split() text = text.split()
elif lang == 'ja': elif lang == "ja":
text = self.moses_pipeline(text, lang=lang) text = self.moses_pipeline(text, lang=lang)
text = self.ja_tokenize(text) text = self.ja_tokenize(text)
else: else:
raise ValueError('It should not reach here') raise ValueError("It should not reach here")
if self.do_lowercase_and_remove_accent and not bypass_tokenizer: if self.do_lowercase_and_remove_accent and not bypass_tokenizer:
text = lowercase_and_remove_accent(text) text = lowercase_and_remove_accent(text)
...@@ -743,7 +792,7 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -743,7 +792,7 @@ class XLMTokenizer(PreTrainedTokenizer):
split_tokens = [] split_tokens = []
for token in text: for token in text:
if token: if token:
split_tokens.extend([t for t in self.bpe(token).split(' ')]) split_tokens.extend([t for t in self.bpe(token).split(" ")])
return split_tokens return split_tokens
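Illustrative sketch (not part of this commit): the branching above amounts to picking a word segmenter per language before BPE is applied; a stripped-down dispatcher with hypothetical names (the real Jieba/PyThaiNLP/KyTea calls are omitted):

def segment(text, lang, moses_tokenize, custom_segmenters):
    # Chinese, Thai and Japanese use dedicated segmenters; other languages go through Moses.
    if lang in custom_segmenters:
        return custom_segmenters[lang](text)
    return moses_tokenize(text, lang)

# Toy usage with a whitespace tokenizer standing in for the real libraries.
print(segment("hello world", "en", lambda t, l: t.split(), {}))  # ['hello', 'world']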
...@@ -757,7 +806,7 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -757,7 +806,7 @@ class XLMTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """ """ Converts a sequence of tokens (string) in a single string. """
out_string = ''.join(tokens).replace('</w>', ' ').strip() out_string = "".join(tokens).replace("</w>", " ").strip()
return out_string return out_string
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
...@@ -792,8 +841,10 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -792,8 +841,10 @@ class XLMTokenizer(PreTrainedTokenizer):
if already_has_special_tokens: if already_has_special_tokens:
if token_ids_1 is not None: if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of " raise ValueError(
"ids is already formated with special tokens for the model.") "You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None: if token_ids_1 is not None:
...@@ -820,20 +871,22 @@ class XLMTokenizer(PreTrainedTokenizer): ...@@ -820,20 +871,22 @@ class XLMTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f: with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False)) f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0 index = 0
with open(merge_file, "w", encoding="utf-8") as writer: with open(merge_file, "w", encoding="utf-8") as writer:
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." logger.warning(
" Please check that the tokenizer is not corrupted!".format(merge_file)) "Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index index = token_index
writer.write(' '.join(bpe_tokens) + u'\n') writer.write(" ".join(bpe_tokens) + "\n")
index += 1 index += 1
return vocab_file, merge_file return vocab_file, merge_file
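Illustrative sketch (not part of this commit): the two files written above are a JSON vocabulary and a merges file containing the BPE pairs sorted by rank, one pair per line; a hypothetical write/read round-trip of the same format:

import json

def save_bpe(encoder, bpe_ranks, vocab_path, merges_path):
    # Vocabulary as JSON, merges as rank-ordered "first second" lines.
    with open(vocab_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(encoder, ensure_ascii=False))
    with open(merges_path, "w", encoding="utf-8") as f:
        for pair, _rank in sorted(bpe_ranks.items(), key=lambda kv: kv[1]):
            f.write(" ".join(pair) + "\n")

def load_bpe(vocab_path, merges_path):
    with open(vocab_path, encoding="utf-8") as f:
        encoder = json.load(f)
    with open(merges_path, encoding="utf-8") as f:
        merges = [tuple(line.split()[:2]) for line in f.read().split("\n") if line]
    return encoder, dict(zip(merges, range(len(merges))))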
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
""" Tokenization classes for XLM-RoBERTa model.""" """ Tokenization classes for XLM-RoBERTa model."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import logging import logging
import os import os
...@@ -26,29 +25,29 @@ from .tokenization_xlnet import SPIECE_UNDERLINE ...@@ -26,29 +25,29 @@ from .tokenization_xlnet import SPIECE_UNDERLINE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'} VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-sentencepiece.bpe.model",
'xlm-roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-sentencepiece.bpe.model", "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-sentencepiece.bpe.model",
'xlm-roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-sentencepiece.bpe.model", "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll02-dutch': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model", "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll02-spanish': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model", "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll03-english': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model", "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll03-german': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model",
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-roberta-base': 512, "xlm-roberta-base": 512,
'xlm-roberta-large': 512, "xlm-roberta-large": 512,
'xlm-roberta-large-finetuned-conll02-dutch': 512, "xlm-roberta-large-finetuned-conll02-dutch": 512,
'xlm-roberta-large-finetuned-conll02-spanish': 512, "xlm-roberta-large-finetuned-conll02-spanish": 512,
'xlm-roberta-large-finetuned-conll03-english': 512, "xlm-roberta-large-finetuned-conll03-english": 512,
'xlm-roberta-large-finetuned-conll03-german': 512, "xlm-roberta-large-finetuned-conll03-german": 512,
} }
class XLMRobertaTokenizer(PreTrainedTokenizer): class XLMRobertaTokenizer(PreTrainedTokenizer):
""" """
Adapted from RobertaTokenizer and XLNetTokenizer Adapted from RobertaTokenizer and XLNetTokenizer
...@@ -56,17 +55,33 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): ...@@ -56,17 +55,33 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
- requires `SentencePiece <https://github.com/google/sentencepiece>`_ - requires `SentencePiece <https://github.com/google/sentencepiece>`_
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>", def __init__(
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', self,
**kwargs): vocab_file,
super(XLMRobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, bos_token="<s>",
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, eos_token="</s>",
mask_token=mask_token, sep_token="</s>",
**kwargs) cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs
):
super(XLMRobertaTokenizer, self).__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
self.sp_model = spm.SentencePieceProcessor() self.sp_model = spm.SentencePieceProcessor()
...@@ -85,7 +100,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): ...@@ -85,7 +100,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
self.fairseq_offset = 1 self.fairseq_offset = 1
self.fairseq_tokens_to_ids['<mask>'] = len(self.sp_model) + len(self.fairseq_tokens_to_ids) self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
...@@ -119,8 +134,10 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): ...@@ -119,8 +134,10 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
""" """
if already_has_special_tokens: if already_has_special_tokens:
if token_ids_1 is not None: if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of " raise ValueError(
"ids is already formated with special tokens for the model.") "You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is None: if token_ids_1 is None:
...@@ -164,7 +181,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): ...@@ -164,7 +181,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string.""" """Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string return out_string
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
...@@ -174,7 +191,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): ...@@ -174,7 +191,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file) copyfile(self.vocab_file, out_vocab_file)
......
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Tokenization classes for XLNet model.""" """ Tokenization classes for XLNet model."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import logging import logging
import os import os
...@@ -27,51 +26,69 @@ from .tokenization_utils import PreTrainedTokenizer ...@@ -27,51 +26,69 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'} VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model", "xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlnet-base-cased': None, "xlnet-base-cased": None,
'xlnet-large-cased': None, "xlnet-large-cased": None,
} }
SPIECE_UNDERLINE = u'▁' SPIECE_UNDERLINE = "▁"
# Segments (not really needed) # Segments (not really needed)
SEG_ID_A = 0 SEG_ID_A = 0
SEG_ID_B = 1 SEG_ID_B = 1
SEG_ID_CLS = 2 SEG_ID_CLS = 2
SEG_ID_SEP = 3 SEG_ID_SEP = 3
SEG_ID_PAD = 4 SEG_ID_PAD = 4
class XLNetTokenizer(PreTrainedTokenizer): class XLNetTokenizer(PreTrainedTokenizer):
""" """
SentencePiece based tokenizer. Peculiarities: SentencePiece based tokenizer. Peculiarities:
- requires `SentencePiece <https://github.com/google/sentencepiece>`_ - requires `SentencePiece <https://github.com/google/sentencepiece>`_
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
padding_side = "left" padding_side = "left"
def __init__(self, vocab_file, def __init__(
do_lower_case=False, remove_space=True, keep_accents=False, self,
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>", vocab_file,
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>", do_lower_case=False,
additional_special_tokens=["<eop>", "<eod>"], **kwargs): remove_space=True,
super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, keep_accents=False,
unk_token=unk_token, sep_token=sep_token, bos_token="<s>",
pad_token=pad_token, cls_token=cls_token, eos_token="</s>",
mask_token=mask_token, additional_special_tokens= unk_token="<unk>",
additional_special_tokens, **kwargs) sep_token="<sep>",
pad_token="<pad>",
cls_token="<cls>",
mask_token="<mask>",
additional_special_tokens=["<eop>", "<eod>"],
**kwargs
):
super(XLNetTokenizer, self).__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
@@ -80,8 +97,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece "
                "pip install sentencepiece"
            )

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
@@ -105,24 +124,26 @@ class XLNetTokenizer(PreTrainedTokenizer):
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece "
                "pip install sentencepiece"
            )

        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)
    def preprocess_text(self, inputs):
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs
        outputs = outputs.replace("``", '"').replace("''", '"')

        if six.PY2 and isinstance(outputs, str):
            outputs = outputs.decode("utf-8")

        if not self.keep_accents:
            outputs = unicodedata.normalize("NFKD", outputs)
            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
        if self.do_lower_case:
            outputs = outputs.lower()
@@ -135,7 +156,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
        text = self.preprocess_text(text)
        # note(zhiliny): in some systems, sentencepiece only accepts str for py2
        if six.PY2 and isinstance(text, unicode):
            text = text.encode("utf-8")

        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
@@ -143,9 +164,8 @@ class XLNetTokenizer(PreTrainedTokenizer):
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        new_pieces = []
        for piece in pieces:
            # re-tokenize pieces like "1990," so the digits and the trailing comma are handled separately
            if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
@@ -161,7 +181,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
            ret_pieces = []
            for piece in new_pieces:
                if isinstance(piece, str):
                    piece = piece.decode("utf-8")
                ret_pieces.append(piece)
            new_pieces = ret_pieces

@@ -175,12 +195,12 @@ class XLNetTokenizer(PreTrainedTokenizer):
        """Converts an index (integer) to a token (string/unicode) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        if six.PY2 and return_unicode and isinstance(token, str):
            token = token.decode("utf-8")
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) to a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
@@ -215,8 +235,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
@@ -247,7 +269,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
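For reference (not part of the commit), a minimal usage sketch of the tokenizer shown above: it assumes the pretrained "xlnet-base-cased" SentencePiece model is available through from_pretrained, and the exact pieces and ids depend on that model.

# Usage sketch only: round-trip a sentence through XLNetTokenizer.
from transformers.tokenization_xlnet import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
tokens = tokenizer.tokenize("This is a test")       # e.g. ['▁This', '▁is', '▁a', '▁test']
ids = tokenizer.convert_tokens_to_ids(tokens)       # vocabulary id for each SentencePiece piece
text = tokenizer.convert_tokens_to_string(tokens)   # '▁' markers become spaces: "This is a test"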
""" Script for downloading all GLUE data.
Original source: https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e

Note: for legal reasons, we are unable to host MRPC.
@@ -16,7 +16,7 @@ rm MSRParaphraseCorpus.msi
1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
"""
import os
import sys
@@ -27,20 +27,23 @@ import urllib.request
import zipfile

TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {
    "CoLA": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4",
    "SST": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8",
    "MRPC": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc",
    "QQP": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5",
    "STS": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5",
    "MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",
    "SNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df",
    "QNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601",
    "RTE": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb",
    "WNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf",
    "diagnostic": "https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D",
}

MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"

def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
@@ -51,6 +54,7 @@ def download_and_extract(task, data_dir):
    os.remove(data_file)
    print("\tCompleted!")

def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
@@ -72,30 +76,32 @@ def format_mrpc(data_dir, path_to_data):
    dev_ids = []
    with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split("\t"))

    with open(mrpc_train_file, encoding="utf8") as data_fh, open(
        os.path.join(mrpc_dir, "train.tsv"), "w", encoding="utf8"
    ) as train_fh, open(os.path.join(mrpc_dir, "dev.tsv"), "w", encoding="utf8") as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split("\t")
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    with open(mrpc_test_file, encoding="utf8") as data_fh, open(
        os.path.join(mrpc_dir, "test.tsv"), "w", encoding="utf8"
    ) as test_fh:
        header = data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split("\t")
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
    print("\tCompleted!")

def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
@@ -105,8 +111,9 @@ def download_diagnostic(data_dir):
    print("\tCompleted!")
    return

def get_tasks(task_names):
    task_names = task_names.split(",")
    if "all" in task_names:
        tasks = TASKS
    else:
@@ -116,13 +123,19 @@ def get_tasks(task_names):
            tasks.append(task_name)
    return tasks

def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", help="directory to save data to", type=str, default="glue_data")
    parser.add_argument(
        "--tasks", help="tasks to download data for as a comma separated string", type=str, default="all"
    )
    parser.add_argument(
        "--path_to_mrpc",
        help="path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt",
        type=str,
        default="",
    )
    args = parser.parse_args(arguments)

    if not os.path.isdir(args.data_dir):
@@ -130,13 +143,13 @@ def main(arguments):
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == "MRPC":
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == "diagnostic":
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
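For reference (not part of the commit), a minimal sketch of how the downloader above can also be driven directly from Python; the output directory and task names are illustrative, and the --path_to_mrpc value is a placeholder for wherever the raw MSR files were extracted.

# Usage sketch only: main() parses the same flags the command line exposes.
main(["--data_dir", "glue_data", "--tasks", "RTE,WNLI"])
# MRPC needs the extra flag when the raw MSR paraphrase files were downloaded manually:
# main(["--data_dir", "glue_data", "--tasks", "MRPC", "--path_to_mrpc", "path/to/extracted_mrpc"])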
@@ -43,7 +43,7 @@ def scan_code_for_links(source):
    """ Scans the file to find links using a regular expression.
    Returns a list of links.
    """
    with open(source, "r") as content:
        content = content.read()
        raw_links = re.findall(REGEXP_FIND_S3_LINKS, content)
        links = [prefix + suffix for _, prefix, suffix in raw_links]
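REGEXP_FIND_S3_LINKS itself is defined outside this hunk. As a rough, self-contained sketch of the same idea, the pattern and helper below are illustrative assumptions, not the project's actual regular expression or API.

# Sketch only: ILLUSTRATIVE_S3_PATTERN is a stand-in, not the project's REGEXP_FIND_S3_LINKS.
import re

ILLUSTRATIVE_S3_PATTERN = r"https://s3\.amazonaws\.com/models\.huggingface\.co/[^\s\"']+"

def scan_text_for_links(text):
    """Return every S3 model URL found in a blob of source code."""
    return re.findall(ILLUSTRATIVE_S3_PATTERN, text)

print(scan_text_for_links('URL = "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model"'))
# ['https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model']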