Commit fa84ae26 authored by Aymeric Augustin

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.

This is also Thomas' preference, because it allows for explicit variable
names, which make the code easier to understand.
parent 63e3827c
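For reference, the 119-character limit can also be applied programmatically. The snippet below is an illustrative sketch and is not part of this commit; it assumes black is installed and exposes the format_str / FileMode API (exact signatures vary slightly between black releases).

    # Illustrative sketch only: reformat a snippet with the same 119-character
    # limit used in this commit, via black's Python API (format_str / FileMode).
    import black

    SOURCE = (
        "def f(vocab_file, do_lower_case=True, remove_space=True,\n"
        "      keep_accents=False, bos_token='[CLS]', eos_token='[SEP]'):\n"
        "    return {'vocab_file': vocab_file, 'lower': do_lower_case}\n"
    )

    formatted = black.format_str(SOURCE, mode=black.FileMode(line_length=119))
    print(formatted)  # quotes are normalized to double quotes; the joined signature fits within 119 characters

The command-line invocation shown in the commit message applies the same rule to whole files.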
@@ -17,13 +17,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
from transformers.tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_tests_commons import CommonTestCases
from .utils import slow
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'fixtures/test_sentencepiece.model')
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
@@ -40,55 +40,135 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"This is a test"
output_text = u"This is a test"
input_text = "This is a test"
output_text = "This is a test"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokens = tokenizer.tokenize(u'This is a test')
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
ids = tokenizer.convert_tokens_to_ids(tokens)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
ids, [8, 21, 84, 55, 24, 19, 7, 0,
602, 347, 347, 347, 3, 12, 66,
46, 72, 80, 6, 0, 4])
tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
".",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in',
SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
u'<unk>', u'.'])
self.assertListEqual(
back_tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
".",
],
)
def test_tokenizer_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE + "",
"i",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"se",
".",
],
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["▁he", "ll", "o"])
def test_tokenizer_no_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or',
u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"se",
".",
],
)
@slow
def test_sequence_builders(self):
@@ -104,5 +184,5 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
assert encoded_pair == text + [4] + text_2 + [4, 3]
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
@@ -27,6 +27,7 @@ def parse_flag_from_env(key, default=False):
raise ValueError("If set, {} must be yes or no.".format(key))
return _value
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for ALBERT model."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
from .tokenization_utils import PreTrainedTokenizer
import logging
@@ -24,34 +23,34 @@ import os
from shutil import copyfile
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'albert-base-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-spiece.model",
'albert-large-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-spiece.model",
'albert-xlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-spiece.model",
'albert-xxlarge-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-spiece.model",
'albert-base-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model",
'albert-large-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-spiece.model",
'albert-xlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-spiece.model",
'albert-xxlarge-v2': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-spiece.model",
"vocab_file": {
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-spiece.model",
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-spiece.model",
"albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-spiece.model",
"albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-spiece.model",
"albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model",
"albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-spiece.model",
"albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-spiece.model",
"albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-spiece.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'albert-base-v1': 512,
'albert-large-v1': 512,
'albert-xlarge-v1': 512,
'albert-xxlarge-v1': 512,
'albert-base-v2': 512,
'albert-large-v2': 512,
'albert-xlarge-v2': 512,
'albert-xxlarge-v2': 512,
"albert-base-v1": 512,
"albert-large-v1": 512,
"albert-xlarge-v1": 512,
"albert-xxlarge-v1": 512,
"albert-base-v2": 512,
"albert-large-v2": 512,
"albert-xlarge-v2": 512,
"albert-xxlarge-v2": 512,
}
SPIECE_UNDERLINE = u'▁'
SPIECE_UNDERLINE = "▁"
class AlbertTokenizer(PreTrainedTokenizer):
"""
@@ -59,18 +58,36 @@ class AlbertTokenizer(PreTrainedTokenizer):
- requires `SentencePiece <https://github.com/google/sentencepiece>`_
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file,
do_lower_case=True, remove_space=True, keep_accents=False,
bos_token="[CLS]", eos_token="[SEP]", unk_token="<unk>", sep_token="[SEP]",
pad_token="<pad>", cls_token="[CLS]", mask_token="[MASK]", **kwargs):
super(AlbertTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token,
unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, **kwargs)
def __init__(
self,
vocab_file,
do_lower_case=True,
remove_space=True,
keep_accents=False,
bos_token="[CLS]",
eos_token="[SEP]",
unk_token="<unk>",
sep_token="[SEP]",
pad_token="<pad>",
cls_token="[CLS]",
mask_token="[MASK]",
**kwargs
):
super(AlbertTokenizer, self).__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
@@ -78,8 +95,10 @@ class AlbertTokenizer(PreTrainedTokenizer):
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece")
logger.warning(
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.do_lower_case = do_lower_case
self.remove_space = remove_space
@@ -103,24 +122,26 @@ class AlbertTokenizer(PreTrainedTokenizer):
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece")
logger.warning(
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)
def preprocess_text(self, inputs):
if self.remove_space:
outputs = ' '.join(inputs.strip().split())
outputs = " ".join(inputs.strip().split())
else:
outputs = inputs
outputs = outputs.replace("``", '"').replace("''", '"')
if six.PY2 and isinstance(outputs, str):
outputs = outputs.decode('utf-8')
outputs = outputs.decode("utf-8")
if not self.keep_accents:
outputs = unicodedata.normalize('NFKD', outputs)
outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
outputs = unicodedata.normalize("NFKD", outputs)
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
if self.do_lower_case:
outputs = outputs.lower()
@@ -133,7 +154,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
text = self.preprocess_text(text)
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
if six.PY2 and isinstance(text, unicode):
text = text.encode('utf-8')
text = text.encode("utf-8")
if not sample:
pieces = self.sp_model.EncodeAsPieces(text)
@@ -141,9 +162,8 @@ class AlbertTokenizer(PreTrainedTokenizer):
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
new_pieces = []
for piece in pieces:
if len(piece) > 1 and piece[-1] == str(',') and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(
piece[:-1].replace(SPIECE_UNDERLINE, ''))
if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
@@ -159,7 +179,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
ret_pieces = []
for piece in new_pieces:
if isinstance(piece, str):
piece = piece.decode('utf-8')
piece = piece.decode("utf-8")
ret_pieces.append(piece)
new_pieces = ret_pieces
@@ -173,12 +193,12 @@ class AlbertTokenizer(PreTrainedTokenizer):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
token = self.sp_model.IdToPiece(index)
if six.PY2 and return_unicode and isinstance(token, str):
token = token.decode('utf-8')
token = token.decode("utf-8")
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
@@ -213,8 +233,10 @@ class AlbertTokenizer(PreTrainedTokenizer):
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
@@ -244,7 +266,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
@@ -35,6 +35,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
logger = logging.getLogger(__name__)
class AutoTokenizer(object):
r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library
@@ -62,9 +63,12 @@ class AutoTokenizer(object):
This class cannot be instantiated using `__init__()` (throw an error).
"""
def __init__(self):
raise EnvironmentError("AutoTokenizer is designed to be instantiated "
"using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.")
raise EnvironmentError(
"AutoTokenizer is designed to be instantiated "
"using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
@@ -125,34 +129,38 @@ class AutoTokenizer(object):
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')
"""
if 't5' in pretrained_model_name_or_path:
if "t5" in pretrained_model_name_or_path:
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
elif "distilbert" in pretrained_model_name_or_path:
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'albert' in pretrained_model_name_or_path:
elif "albert" in pretrained_model_name_or_path:
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'camembert' in pretrained_model_name_or_path:
elif "camembert" in pretrained_model_name_or_path:
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlm-roberta' in pretrained_model_name_or_path:
elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
elif "roberta" in pretrained_model_name_or_path:
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'bert-base-japanese' in pretrained_model_name_or_path:
elif "bert-base-japanese" in pretrained_model_name_or_path:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
elif "bert" in pretrained_model_name_or_path:
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'openai-gpt' in pretrained_model_name_or_path:
elif "openai-gpt" in pretrained_model_name_or_path:
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'gpt2' in pretrained_model_name_or_path:
elif "gpt2" in pretrained_model_name_or_path:
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'transfo-xl' in pretrained_model_name_or_path:
elif "transfo-xl" in pretrained_model_name_or_path:
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
elif "xlnet" in pretrained_model_name_or_path:
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
elif "xlm" in pretrained_model_name_or_path:
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path:
elif "ctrl" in pretrained_model_name_or_path:
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path))
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(
pretrained_model_name_or_path
)
)
@@ -26,69 +26,68 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
"vocab_file": {
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
"bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-uncased': 512,
'bert-large-uncased': 512,
'bert-base-cased': 512,
'bert-large-cased': 512,
'bert-base-multilingual-uncased': 512,
'bert-base-multilingual-cased': 512,
'bert-base-chinese': 512,
'bert-base-german-cased': 512,
'bert-large-uncased-whole-word-masking': 512,
'bert-large-cased-whole-word-masking': 512,
'bert-large-uncased-whole-word-masking-finetuned-squad': 512,
'bert-large-cased-whole-word-masking-finetuned-squad': 512,
'bert-base-cased-finetuned-mrpc': 512,
'bert-base-german-dbmdz-cased': 512,
'bert-base-german-dbmdz-uncased': 512,
'bert-base-finnish-cased-v1': 512,
'bert-base-finnish-uncased-v1': 512,
"bert-base-uncased": 512,
"bert-large-uncased": 512,
"bert-base-cased": 512,
"bert-large-cased": 512,
"bert-base-multilingual-uncased": 512,
"bert-base-multilingual-cased": 512,
"bert-base-chinese": 512,
"bert-base-german-cased": 512,
"bert-large-uncased-whole-word-masking": 512,
"bert-large-cased-whole-word-masking": 512,
"bert-large-uncased-whole-word-masking-finetuned-squad": 512,
"bert-large-cased-whole-word-masking-finetuned-squad": 512,
"bert-base-cased-finetuned-mrpc": 512,
"bert-base-german-dbmdz-cased": 512,
"bert-base-german-dbmdz-uncased": 512,
"bert-base-finnish-cased-v1": 512,
"bert-base-finnish-uncased-v1": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
'bert-base-uncased': {'do_lower_case': True},
'bert-large-uncased': {'do_lower_case': True},
'bert-base-cased': {'do_lower_case': False},
'bert-large-cased': {'do_lower_case': False},
'bert-base-multilingual-uncased': {'do_lower_case': True},
'bert-base-multilingual-cased': {'do_lower_case': False},
'bert-base-chinese': {'do_lower_case': False},
'bert-base-german-cased': {'do_lower_case': False},
'bert-large-uncased-whole-word-masking': {'do_lower_case': True},
'bert-large-cased-whole-word-masking': {'do_lower_case': False},
'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
'bert-base-german-dbmdz-cased': {'do_lower_case': False},
'bert-base-german-dbmdz-uncased': {'do_lower_case': True},
'bert-base-finnish-cased-v1': {'do_lower_case': False},
'bert-base-finnish-uncased-v1': {'do_lower_case': True},
"bert-base-uncased": {"do_lower_case": True},
"bert-large-uncased": {"do_lower_case": True},
"bert-base-cased": {"do_lower_case": False},
"bert-large-cased": {"do_lower_case": False},
"bert-base-multilingual-uncased": {"do_lower_case": True},
"bert-base-multilingual-cased": {"do_lower_case": False},
"bert-base-chinese": {"do_lower_case": False},
"bert-base-german-cased": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking": {"do_lower_case": True},
"bert-large-cased-whole-word-masking": {"do_lower_case": False},
"bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
"bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
"bert-base-german-dbmdz-cased": {"do_lower_case": False},
"bert-base-german-dbmdz-uncased": {"do_lower_case": True},
"bert-base-finnish-cased-v1": {"do_lower_case": False},
"bert-base-finnish-uncased-v1": {"do_lower_case": True},
}
@@ -98,7 +97,7 @@ def load_vocab(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip('\n')
token = token.rstrip("\n")
vocab[token] = index
return vocab
@@ -132,9 +131,20 @@ class BertTokenizer(PreTrainedTokenizer):
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
**kwargs
):
"""Constructs a BertTokenizer.
Args:
@@ -152,24 +162,29 @@ class BertTokenizer(PreTrainedTokenizer):
This should likely be deactivated for Japanese:
see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
"""
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, **kwargs)
super(BertTokenizer, self).__init__(
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars)
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
@property
@@ -196,7 +211,7 @@ class BertTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = ' '.join(tokens).replace(' ##', '').strip()
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
@@ -231,8 +246,10 @@ class BertTokenizer(PreTrainedTokenizer):
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
@@ -258,16 +275,18 @@ class BertTokenizer(PreTrainedTokenizer):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
else:
vocab_file = vocab_path
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file))
logger.warning(
"Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file)
)
index = token_index
writer.write(token + u'\n')
writer.write(token + "\n")
index += 1
return (vocab_file,)
@@ -382,14 +401,16 @@ class BasicTokenizer(object):
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
@@ -399,7 +420,7 @@ class BasicTokenizer(object):
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
if cp == 0 or cp == 0xFFFD or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
@@ -499,8 +520,7 @@ def _is_punctuation(char):
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
@@ -28,46 +28,45 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-vocab.txt",
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-vocab.txt",
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-vocab.txt",
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-vocab.txt"
"vocab_file": {
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-vocab.txt",
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-vocab.txt",
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-vocab.txt",
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-japanese': 512,
'bert-base-japanese-whole-word-masking': 512,
'bert-base-japanese-char': 512,
'bert-base-japanese-char-whole-word-masking': 512
"bert-base-japanese": 512,
"bert-base-japanese-whole-word-masking": 512,
"bert-base-japanese-char": 512,
"bert-base-japanese-char-whole-word-masking": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
'bert-base-japanese': {
'do_lower_case': False,
'word_tokenizer_type': 'mecab',
'subword_tokenizer_type': 'wordpiece'
"bert-base-japanese": {
"do_lower_case": False,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "wordpiece",
},
'bert-base-japanese-whole-word-masking':{
'do_lower_case': False,
'word_tokenizer_type': 'mecab',
'subword_tokenizer_type': 'wordpiece'
"bert-base-japanese-whole-word-masking": {
"do_lower_case": False,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "wordpiece",
},
'bert-base-japanese-char': {
'do_lower_case': False,
'word_tokenizer_type': 'mecab',
'subword_tokenizer_type': 'character'
"bert-base-japanese-char": {
"do_lower_case": False,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "character",
},
"bert-base-japanese-char-whole-word-masking": {
"do_lower_case": False,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "character",
},
'bert-base-japanese-char-whole-word-masking': {
'do_lower_case': False,
'word_tokenizer_type': 'mecab',
'subword_tokenizer_type': 'character'
}
}
@@ -79,11 +78,22 @@ class BertJapaneseTokenizer(BertTokenizer):
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=False,
do_word_tokenize=True, do_subword_tokenize=True,
word_tokenizer_type='basic', subword_tokenizer_type='wordpiece',
never_split=None, unk_token='[UNK]', sep_token='[SEP]',
pad_token='[PAD]', cls_token='[CLS]', mask_token='[MASK]', **kwargs):
def __init__(
self,
vocab_file,
do_lower_case=False,
do_word_tokenize=True,
do_subword_tokenize=True,
word_tokenizer_type="basic",
subword_tokenizer_type="wordpiece",
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
**kwargs
):
"""Constructs a MecabBertTokenizer.
Args:
@@ -100,56 +110,53 @@ class BertJapaneseTokenizer(BertTokenizer):
**subword_tokenizer_type**: (`optional`) string (default "wordpiece")
Type of subword tokenizer.
"""
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, **kwargs)
super(BertTokenizer, self).__init__(
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_word_tokenize = do_word_tokenize
if do_word_tokenize:
if word_tokenizer_type == 'basic':
self.word_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=False)
elif word_tokenizer_type == 'mecab':
self.word_tokenizer = MecabTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
if word_tokenizer_type == "basic":
self.word_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=False
)
elif word_tokenizer_type == "mecab":
self.word_tokenizer = MecabTokenizer(do_lower_case=do_lower_case, never_split=never_split)
else:
raise ValueError(
"Invalid word_tokenizer_type '{}' is specified.".format(word_tokenizer_type))
raise ValueError("Invalid word_tokenizer_type '{}' is specified.".format(word_tokenizer_type))
self.do_subword_tokenize = do_subword_tokenize
if do_subword_tokenize:
if subword_tokenizer_type == 'wordpiece':
self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab,
unk_token=self.unk_token)
elif subword_tokenizer_type == 'character':
self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab,
unk_token=self.unk_token)
if subword_tokenizer_type == "wordpiece":
self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
elif subword_tokenizer_type == "character":
self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token)
else:
raise ValueError(
"Invalid subword_tokenizer_type '{}' is specified.".format(subword_tokenizer_type))
raise ValueError("Invalid subword_tokenizer_type '{}' is specified.".format(subword_tokenizer_type))
def _tokenize(self, text):
if self.do_word_tokenize:
tokens = self.word_tokenizer.tokenize(text,
never_split=self.all_special_tokens)
tokens = self.word_tokenizer.tokenize(text, never_split=self.all_special_tokens)
else:
tokens = [text]
if self.do_subword_tokenize:
split_tokens = [sub_token for token in tokens
for sub_token in self.subword_tokenizer.tokenize(token)]
split_tokens = [sub_token for token in tokens for sub_token in self.subword_tokenizer.tokenize(token)]
else:
split_tokens = tokens
@@ -177,27 +184,28 @@ class MecabTokenizer(object):
self.normalize_text = normalize_text
import MeCab
self.mecab = MeCab.Tagger()
def tokenize(self, text, never_split=None, **kwargs):
"""Tokenizes a piece of text."""
if self.normalize_text:
text = unicodedata.normalize('NFKC', text)
text = unicodedata.normalize("NFKC", text)
never_split = self.never_split + (never_split if never_split is not None else [])
tokens = []
if six.PY2:
mecab_output = self.mecab.parse(text.encode('utf-8')).decode('utf-8')
mecab_output = self.mecab.parse(text.encode("utf-8")).decode("utf-8")
else:
mecab_output = self.mecab.parse(text)
cursor = 0
for line in mecab_output.split('\n'):
if line == 'EOS':
for line in mecab_output.split("\n"):
if line == "EOS":
break
token, _ = line.split('\t')
token, _ = line.split("\t")
token_start = text.index(token, cursor)
token_end = token_start + len(token)
if self.do_lower_case and token not in never_split:
@@ -240,7 +248,7 @@ class CharacterTokenizer(object):
A list of characters.
"""
if self.normalize_text:
text = unicodedata.normalize('NFKC', text)
text = unicodedata.normalize("NFKC", text)
output_tokens = []
for i, char in enumerate(text):
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
""" Tokenization classes for Camembert model."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
@@ -26,19 +25,19 @@ from .tokenization_xlnet import SPIECE_UNDERLINE
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model",
"vocab_file": {
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'camembert-base': None,
"camembert-base": None,
}
class CamembertTokenizer(PreTrainedTokenizer):
"""
Adapted from RobertaTokenizer and XLNetTokenizer
@@ -46,17 +45,36 @@ class CamembertTokenizer(PreTrainedTokenizer):
- requires `SentencePiece <https://github.com/google/sentencepiece>`_
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>",
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>',
additional_special_tokens=['<s>NOTUSED', '</s>NOTUSED'], **kwargs):
super(CamembertTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
mask_token=mask_token, additional_special_tokens=additional_special_tokens,
**kwargs)
def __init__(
self,
vocab_file,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
additional_special_tokens=["<s>NOTUSED", "</s>NOTUSED"],
**kwargs
):
super(CamembertTokenizer, self).__init__(
max_len=512,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
self.sp_model = spm.SentencePieceProcessor()
@@ -64,9 +82,9 @@ class CamembertTokenizer(PreTrainedTokenizer):
self.vocab_file = vocab_file
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
# sentencepiece vocabulary (this is the case for <s> and </s>
self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
self.fairseq_tokens_to_ids = {"<s>NOTUSED": 0, "<pad>": 1, "</s>NOTUSED": 2, "<unk>": 3}
self.fairseq_offset = len(self.fairseq_tokens_to_ids)
self.fairseq_tokens_to_ids['<mask>'] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
@@ -100,8 +118,10 @@ class CamembertTokenizer(PreTrainedTokenizer):
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is None:
@@ -148,7 +168,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def save_vocabulary(self, save_directory):
@@ -158,7 +178,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Salesforce CTRL."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
@@ -27,23 +26,17 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json",
},
'merges_file':
{
'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt",
},
"vocab_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json",},
"merges_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt",},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'ctrl': 256,
"ctrl": 256,
}
CONTROL_CODES = {
@@ -104,6 +97,7 @@ CONTROL_CODES = {
"multilingual": 128406,
}
def get_pairs(word):
"""Return set of symbol pairs in a word.
@@ -118,11 +112,13 @@ def get_pairs(word):
pairs = set(pairs)
return pairs
class CTRLTokenizer(PreTrainedTokenizer):
"""
CTRL BPE tokenizer. Peculiarities:
- Byte-Pair-Encoding
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
@@ -130,14 +126,18 @@ class CTRLTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v:k for k,v in self.encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle:
merges = merges_handle.read().split('\n')[1:-1]
self.decoder = {v: k for k, v in self.encoder.items()}
with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split("\n")[1:-1]
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
@@ -150,14 +150,14 @@ class CTRLTokenizer(PreTrainedTokenizer):
if token in self.cache:
return self.cache[token]
word = tuple(token)
word = tuple(list(word[:-1]) + [word[-1]+'</w>'])
word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
@@ -172,8 +172,8 @@ class CTRLTokenizer(PreTrainedTokenizer):
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
@@ -184,7 +184,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
break
else:
pairs = get_pairs(word)
word = '@@ '.join(word)
word = "@@ ".join(word)
word = word[:-4]
self.cache[token] = word
return word
@@ -194,10 +194,10 @@ class CTRLTokenizer(PreTrainedTokenizer):
"""
split_tokens = []
words = re.findall(r'\S+\n?', text)
words = re.findall(r"\S+\n?", text)
for token in words:
split_tokens.extend([t for t in self.bpe(token).split(' ')])
split_tokens.extend([t for t in self.bpe(token).split(" ")])
return split_tokens
def _convert_token_to_id(self, token):
@@ -210,7 +210,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = ' '.join(tokens).replace('@@ ', '').strip()
out_string = " ".join(tokens).replace("@@ ", "").strip()
return out_string
def save_vocabulary(self, save_directory):
@@ -218,21 +218,23 @@ class CTRLTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f:
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
logger.warning(
"Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return vocab_file, merge_file
@@ -26,23 +26,22 @@ from .tokenization_bert import BertTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt",
'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
"vocab_file": {
"distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
"distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
"distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt",
"distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'distilbert-base-uncased': 512,
'distilbert-base-uncased-distilled-squad': 512,
'distilbert-base-german-cased': 512,
'distilbert-base-multilingual-cased': 512,
"distilbert-base-uncased": 512,
"distilbert-base-uncased-distilled-squad": 512,
"distilbert-base-german-cased": 512,
"distilbert-base-multilingual-cased": 512,
}
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
import json
@@ -31,42 +30,42 @@ except ImportError:
def lru_cache():
return lambda func: func
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
"vocab_file": {
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json",
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
},
'merges_file':
{
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
"merges_file": {
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
"gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt",
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'gpt2': 1024,
'gpt2-medium': 1024,
'gpt2-large': 1024,
'gpt2-xl': 1024,
'distilgpt2': 1024,
"gpt2": 1024,
"gpt2-medium": 1024,
"gpt2-large": 1024,
"gpt2-xl": 1024,
"distilgpt2": 1024,
}
@lru_cache()
def bytes_to_unicode():
"""
@@ -80,17 +79,20 @@ def bytes_to_unicode():
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
cs.append(2 ** 8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
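For reference, a small sketch of what the table built above looks like in practice (Python 3; assuming the function is used as defined here): printable bytes map to themselves, while bytes such as the space character are shifted past U+0100 so every byte has a visible, merge-safe stand-in.

byte_encoder = bytes_to_unicode()
assert byte_encoder[ord("A")] == "A"       # printable ASCII is unchanged
assert byte_encoder[ord(" ")] == "\u0120"  # space (byte 32) becomes "Ġ" (256 + 32)
# "hello world" is therefore carried through the reversible byte alphabet as "helloĠworld"
mapped = "".join(byte_encoder[b] for b in "hello world".encode("utf-8"))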
def get_pairs(word):
"""Return set of symbol pairs in a word.
......@@ -103,6 +105,7 @@ def get_pairs(word):
prev_char = char
return pairs
class GPT2Tokenizer(PreTrainedTokenizer):
"""
GPT-2 BPE tokenizer. Peculiarities:
......@@ -112,15 +115,28 @@ class GPT2Tokenizer(PreTrainedTokenizer):
Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve
the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"`
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>",
bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
**kwargs
):
super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
......@@ -128,8 +144,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle:
bpe_merges = merges_handle.read().split('\n')[1:-1]
with open(merges_file, encoding="utf-8") as merges_handle:
bpe_merges = merges_handle.read().split("\n")[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
......@@ -151,7 +167,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
......@@ -166,8 +182,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
......@@ -178,7 +194,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
word = " ".join(word)
self.cache[token] = word
return word
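A hedged walk-through of the merge loop above, using a made-up rank table; the real ranks come from merges.txt.

# Hypothetical example: token = "hug", bpe_ranks = {("h", "u"): 0, ("hu", "g"): 1}
#   word  = ("h", "u", "g")            pairs = {("h", "u"), ("u", "g")}
#   best ranked pair ("h", "u") is merged  -> word = ("hu", "g")
#   best ranked pair ("hu", "g") is merged -> word = ("hug",)  -> loop exits
#   the space-joined result "hug" is stored in self.cache and returned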
......@@ -189,15 +205,19 @@ class GPT2Tokenizer(PreTrainedTokenizer):
Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
"""
if add_prefix_space:
text = ' ' + text
text = " " + text
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
                token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
                token = "".join(
                    self.byte_encoder[ord(b)] for b in token
                )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
            else:
                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
                token = "".join(
                    self.byte_encoder[b] for b in token.encode("utf-8")
                )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
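A minimal usage sketch of the ``add_prefix_space`` flag handled above (it is forwarded from ``tokenize``/``encode`` via ``**kwargs``; the ``gpt2`` shortcut is assumed to be downloadable):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# With the flag, the text is tokenized as if it followed a space, so the first
# word gets the same "Ġ"-prefixed piece it would get in the middle of a sentence.
tokens_plain = tokenizer.tokenize("Hello world")
tokens_spaced = tokenizer.tokenize("Hello world", add_prefix_space=True)
# tokens_spaced[0] starts with "Ġ"; tokens_plain[0] does not.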
def _convert_token_to_id(self, token):
......@@ -210,8 +230,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
text = ''.join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def save_vocabulary(self, save_directory):
......@@ -219,21 +239,23 @@ class GPT2Tokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f:
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
logger.warning(
"Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return vocab_file, merge_file
......@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
......@@ -28,25 +27,20 @@ from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
},
'merges_file':
{
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
},
"vocab_file": {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",},
"merges_file": {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'openai-gpt': 512,
"openai-gpt": 512,
}
def get_pairs(word):
"""
Return set of symbol pairs in a word.
......@@ -59,27 +53,30 @@ def get_pairs(word):
prev_char = char
return pairs
def text_standardize(text):
"""
    Fixes some issues the spacy tokenizer had on books corpus;
    also does some whitespace standardization.
"""
text = text.replace('—', '-')
text = text.replace('–', '-')
text = text.replace('―', '-')
text = text.replace('…', '...')
text = text.replace('´', "'")
text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
text = re.sub(r'\s*\n\s*', ' \n ', text)
text = re.sub(r'[^\S\n]+', ' ', text)
text = text.replace("—", "-")
text = text.replace("–", "-")
text = text.replace("―", "-")
text = text.replace("…", "...")
text = text.replace("´", "'")
text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text)
text = re.sub(r"\s*\n\s*", " \n ", text)
text = re.sub(r"[^\S\n]+", " ", text)
return text.strip()
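A hedged before/after sketch of the normalization above (dashes and ellipses are ASCII-fied, the listed punctuation is spaced out, and whitespace around newlines is squeezed):

# text_standardize('He said–"wait… really?"\n\nOK')
# -> 'He said - " wait... really ? " \n OK'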
class OpenAIGPTTokenizer(PreTrainedTokenizer):
"""
BPE tokenizer. Peculiarities:
    - lower-cases all inputs
    - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
......@@ -87,12 +84,17 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
try:
import ftfy
from spacy.lang.en import English
_nlp = English()
self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
self.fix_text = ftfy.fix_text
......@@ -103,9 +105,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v:k for k,v in self.encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle:
merges = merges_handle.read().split('\n')[1:-1]
self.decoder = {v: k for k, v in self.encoder.items()}
with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split("\n")[1:-1]
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
......@@ -115,16 +117,16 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
return len(self.encoder)
def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',)
word = tuple(token[:-1]) + (token[-1] + "</w>",)
if token in self.cache:
return self.cache[token]
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
......@@ -139,8 +141,8 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
......@@ -151,9 +153,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
if word == '\n </w>':
word = '\n</w>'
word = " ".join(word)
if word == "\n </w>":
word = "\n</w>"
self.cache[token] = word
return word
......@@ -164,12 +166,12 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
# Using BERT's BasicTokenizer
text = self.nlp.tokenize(text)
for token in text:
split_tokens.extend([t for t in self.bpe(token).split(' ')])
split_tokens.extend([t for t in self.bpe(token).split(" ")])
else:
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
text = self.nlp(text_standardize(self.fix_text(text)))
for token in text:
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(" ")])
return split_tokens
def _convert_token_to_id(self, token):
......@@ -182,7 +184,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = ''.join(tokens).replace('</w>', ' ').strip()
out_string = "".join(tokens).replace("</w>", " ").strip()
return out_string
def save_vocabulary(self, save_directory):
......@@ -190,21 +192,23 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f:
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
logger.warning(
"Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return vocab_file, merge_file
......@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RoBERTa."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
import json
......@@ -33,41 +32,40 @@ except ImportError:
def lru_cache():
return lambda func: func
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-vocab.json",
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
"vocab_file": {
"roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
"roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
"roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
"distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-vocab.json",
"roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
"roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
},
'merges_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-merges.txt",
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
"merges_file": {
"roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
"roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
"roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
"distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-merges.txt",
"roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
"roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'roberta-base': 512,
'roberta-large': 512,
'roberta-large-mnli': 512,
'distilroberta-base': 512,
'roberta-base-openai-detector': 512,
'roberta-large-openai-detector': 512,
"roberta-base": 512,
"roberta-large": 512,
"roberta-large-mnli": 512,
"distilroberta-base": 512,
"roberta-base-openai-detector": 512,
"roberta-large-openai-detector": 512,
}
......@@ -80,16 +78,38 @@ class RobertaTokenizer(GPT2Tokenizer):
    Otherwise, this tokenizer's ``encode`` and ``decode`` methods will not conserve
the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>",
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs):
super(RobertaTokenizer, self).__init__(vocab_file=vocab_file, merges_file=merges_file, errors=errors,
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
mask_token=mask_token, **kwargs)
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs
):
super(RobertaTokenizer, self).__init__(
vocab_file=vocab_file,
merges_file=merges_file,
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
......@@ -124,8 +144,10 @@ class RobertaTokenizer(GPT2Tokenizer):
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is None:
......
......@@ -26,26 +26,25 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
SPIECE_UNDERLINE = u'▁'
SPIECE_UNDERLINE = "▁"
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
"vocab_file": {
"t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
"t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
"t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
"t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
"t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
}
}
......@@ -53,13 +52,14 @@ PRETRAINED_VOCAB_FILES_MAP = {
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
't5-small': 512,
't5-base': 512,
't5-large': 512,
't5-3b': 512,
't5-11b': 512,
"t5-small": 512,
"t5-base": 512,
"t5-large": 512,
"t5-3b": 512,
"t5-11b": 512,
}
class T5Tokenizer(PreTrainedTokenizer):
"""
SentencePiece based tokenizer. Peculiarities:
......@@ -71,28 +71,43 @@ class T5Tokenizer(PreTrainedTokenizer):
(like in T5 preprocessing
see: https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, eos_token="</s>", unk_token="<unk>",
pad_token="<pad>", extra_ids=100, additional_special_tokens=None, **kwargs):
def __init__(
self,
vocab_file,
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>",
extra_ids=100,
additional_special_tokens=None,
**kwargs
):
# Add extra_ids to the special token list
if extra_ids > 0:
if additional_special_tokens is None:
additional_special_tokens = []
additional_special_tokens.extend([u"<extra_id_{}>".format(i) for i in range(extra_ids)])
additional_special_tokens.extend(["<extra_id_{}>".format(i) for i in range(extra_ids)])
super(T5Tokenizer, self).__init__(eos_token=eos_token, unk_token=unk_token,
pad_token=pad_token, additional_special_tokens=additional_special_tokens,
**kwargs)
super(T5Tokenizer, self).__init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use T5Tokenizer:"
logger.warning(
"You need to install SentencePiece to use T5Tokenizer:"
"https://github.com/google/sentencepiece"
"pip install sentencepiece")
"pip install sentencepiece"
)
self.vocab_file = vocab_file
self._extra_ids = extra_ids
......@@ -114,8 +129,10 @@ class T5Tokenizer(PreTrainedTokenizer):
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece")
logger.warning(
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)
......@@ -132,7 +149,7 @@ class T5Tokenizer(PreTrainedTokenizer):
ret_pieces = []
for piece in pieces:
if isinstance(piece, str):
piece = piece.decode('utf-8')
piece = piece.decode("utf-8")
ret_pieces.append(piece)
pieces = ret_pieces
......@@ -140,8 +157,8 @@ class T5Tokenizer(PreTrainedTokenizer):
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """
if token.startswith(u"<extra_id_"):
l = re.match(r'<extra_id_(\d+)>', token)
if token.startswith("<extra_id_"):
l = re.match(r"<extra_id_(\d+)>", token)
num = int(l.group(1))
return self.vocab_size - num - 1
return self.sp_model.piece_to_id(token)
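With the default ``extra_ids=100`` the sentinel tokens sit at the very top of the vocabulary; a sketch of the arithmetic above, assuming the released T5 SentencePiece model with 32,000 pieces (so ``vocab_size == 32100``):

#   "<extra_id_0>"  -> 32100 - 0 - 1  == 32099   (the last id)
#   "<extra_id_99>" -> 32100 - 99 - 1 == 32000   (the first id above the spm pieces)
# _convert_id_to_token below applies the inverse mapping: 32099 -> "<extra_id_0>".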
......@@ -151,9 +168,9 @@ class T5Tokenizer(PreTrainedTokenizer):
if index < self.sp_model.get_piece_size():
token = self.sp_model.IdToPiece(index)
else:
token = u"<extra_id_{}>".format(self.vocab_size - 1 - index)
token = "<extra_id_{}>".format(self.vocab_size - 1 - index)
if six.PY2 and return_unicode and isinstance(token, str):
token = token.decode('utf-8')
token = token.decode("utf-8")
return token
def convert_tokens_to_string(self, tokens):
......@@ -168,7 +185,7 @@ class T5Tokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
......
......@@ -16,8 +16,7 @@
""" Tokenization classes for Transformer XL model.
Adapted from https://github.com/kimiyoung/transformer-xl.
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import glob
import logging
......@@ -44,42 +43,58 @@ except ImportError:
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'}
VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
'pretrained_vocab_file':
{
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
"pretrained_vocab_file": {
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'transfo-xl-wt103': None,
"transfo-xl-wt103": None,
}
PRETRAINED_CORPUS_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
}
CORPUS_NAME = 'corpus.bin'
CORPUS_NAME = "corpus.bin"
class TransfoXLTokenizer(PreTrainedTokenizer):
"""
Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,
delimiter=None, vocab_file=None, pretrained_vocab_file=None,
never_split=None, unk_token="<unk>", eos_token="<eos>",
additional_special_tokens=["<formula>"], **kwargs):
super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
def __init__(
self,
special=None,
min_freq=0,
max_size=None,
lower_case=False,
delimiter=None,
vocab_file=None,
pretrained_vocab_file=None,
never_split=None,
unk_token="<unk>",
eos_token="<eos>",
additional_special_tokens=["<formula>"],
**kwargs
):
super(TransfoXLTokenizer, self).__init__(
unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs
)
self.max_len_single_sentence = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
if never_split is None:
never_split = self.all_special_tokens
......@@ -106,14 +121,15 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.build_vocab()
def count_file(self, path, verbose=False, add_eos=False):
if verbose: logger.info('counting file {} ...'.format(path))
if verbose:
logger.info("counting file {} ...".format(path))
assert os.path.exists(path)
sents = []
with open(path, 'r', encoding='utf-8') as f:
with open(path, "r", encoding="utf-8") as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
logger.info(' line {}'.format(idx))
logger.info(" line {}".format(idx))
symbols = self.tokenize(line, add_eos=add_eos)
self.counter.update(symbols)
sents.append(symbols)
......@@ -124,42 +140,42 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
"""
sents : a list of sentences, each a list of tokenized symbols
"""
if verbose: logger.info('counting {} sents ...'.format(len(sents)))
if verbose:
logger.info("counting {} sents ...".format(len(sents)))
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
logger.info(' line {}'.format(idx))
logger.info(" line {}".format(idx))
self.counter.update(symbols)
def _build_from_file(self, vocab_file):
self.idx2sym = []
self.sym2idx = OrderedDict()
with open(vocab_file, 'r', encoding='utf-8') as f:
with open(vocab_file, "r", encoding="utf-8") as f:
for line in f:
symb = line.strip().split()[0]
self.add_symbol(symb)
if '<UNK>' in self.sym2idx:
self.unk_idx = self.sym2idx['<UNK>']
elif '<unk>' in self.sym2idx:
self.unk_idx = self.sym2idx['<unk>']
if "<UNK>" in self.sym2idx:
self.unk_idx = self.sym2idx["<UNK>"]
elif "<unk>" in self.sym2idx:
self.unk_idx = self.sym2idx["<unk>"]
else:
            raise ValueError('No <unk> token in vocabulary')
            raise ValueError("No <unk> token in vocabulary")
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file."""
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['pretrained_vocab_file'])
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
torch.save(self.__dict__, vocab_file)
return (vocab_file,)
def build_vocab(self):
if self.vocab_file:
logger.info('building vocab from {}'.format(self.vocab_file))
logger.info("building vocab from {}".format(self.vocab_file))
self._build_from_file(self.vocab_file)
logger.info('final vocab size {}'.format(len(self)))
logger.info("final vocab size {}".format(len(self)))
else:
logger.info('building vocab with min_freq={}, max_size={}'.format(
self.min_freq, self.max_size))
logger.info("building vocab with min_freq={}, max_size={}".format(self.min_freq, self.max_size))
self.idx2sym = []
self.sym2idx = OrderedDict()
......@@ -167,23 +183,22 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.add_special(sym)
for sym, cnt in self.counter.most_common(self.max_size):
if cnt < self.min_freq: break
if cnt < self.min_freq:
break
self.add_symbol(sym)
logger.info('final vocab size {} from {} unique tokens'.format(
len(self), len(self.counter)))
logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))
def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
add_double_eos=False):
if verbose: logger.info('encoding file {} ...'.format(path))
def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
if verbose:
logger.info("encoding file {} ...".format(path))
assert os.path.exists(path)
encoded = []
with open(path, 'r', encoding='utf-8') as f:
with open(path, "r", encoding="utf-8") as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
logger.info(' line {}'.format(idx))
symbols = self.tokenize(line, add_eos=add_eos,
add_double_eos=add_double_eos)
logger.info(" line {}".format(idx))
symbols = self.tokenize(line, add_eos=add_eos, add_double_eos=add_double_eos)
encoded.append(self.convert_to_tensor(symbols))
if ordered:
......@@ -192,11 +207,12 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
return encoded
def encode_sents(self, sents, ordered=False, verbose=False):
if verbose: logger.info('encoding {} sents ...'.format(len(sents)))
if verbose:
logger.info("encoding {} sents ...".format(len(sents)))
encoded = []
for idx, symbols in enumerate(sents):
if verbose and idx > 0 and idx % 500000 == 0:
logger.info(' line {}'.format(idx))
logger.info(" line {}".format(idx))
encoded.append(self.convert_to_tensor(symbols))
if ordered:
......@@ -208,7 +224,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if sym not in self.sym2idx:
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1
setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
setattr(self, "{}_idx".format(sym.strip("<>")), self.sym2idx[sym])
def add_symbol(self, sym):
if sym not in self.sym2idx:
......@@ -217,7 +233,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
def _convert_id_to_token(self, idx):
"""Converts an id in a token (BPE) using the vocab."""
assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
return self.idx2sym[idx]
def _convert_token_to_id(self, sym):
......@@ -227,19 +243,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
else:
# logger.info('encounter unk {}'.format(sym))
# assert '<eos>' not in sym
if hasattr(self, 'unk_idx'):
if hasattr(self, "unk_idx"):
return self.sym2idx.get(sym, self.unk_idx)
# Backward compatibility with pre-trained models
elif '<unk>' in self.sym2idx:
return self.sym2idx['<unk>']
elif '<UNK>' in self.sym2idx:
return self.sym2idx['<UNK>']
elif "<unk>" in self.sym2idx:
return self.sym2idx["<unk>"]
elif "<UNK>" in self.sym2idx:
return self.sym2idx["<UNK>"]
else:
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = ' '.join(tokens).strip()
out_string = " ".join(tokens).strip()
return out_string
def convert_to_tensor(self, symbols):
......@@ -256,21 +272,21 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
line = line.lower()
        # an empty delimiter '' evaluates to False
if self.delimiter == '':
if self.delimiter == "":
symbols = line
else:
symbols = line.split(self.delimiter)
if add_double_eos: # lm1b
return ['<S>'] + symbols + ['<S>']
return ["<S>"] + symbols + ["<S>"]
elif add_eos:
return symbols + ['<eos>']
return symbols + ["<eos>"]
else:
return symbols
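A sketch of the three return branches above, assuming the default whitespace delimiter and ``lower_case=False``:

#   "Hello world"                    -> ["Hello", "world"]
#   with add_eos=True                -> ["Hello", "world", "<eos>"]
#   with add_double_eos=True (lm1b)  -> ["<S>", "Hello", "world", "<S>"]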
class LMOrderedIterator(object):
def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
def __init__(self, data, bsz, bptt, device="cpu", ext_len=None):
"""
data -- LongTensor -- the LongTensor is strictly ordered
"""
......@@ -293,14 +309,15 @@ class LMOrderedIterator(object):
self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
def get_batch(self, i, bptt=None):
if bptt is None: bptt = self.bptt
if bptt is None:
bptt = self.bptt
seq_len = min(bptt, self.data.size(0) - 1 - i)
end_idx = i + seq_len
beg_idx = max(0, i - self.ext_len)
data = self.data[beg_idx:end_idx]
target = self.data[i+1:i+1+seq_len]
target = self.data[i + 1 : i + 1 + seq_len]
data_out = data.transpose(0, 1).contiguous().to(self.device)
target_out = target.transpose(0, 1).contiguous().to(self.device)
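In other words (ignoring ``ext_len`` for simplicity), a batch is an input slice and the same slice shifted by one token:

#   data   = self.data[i     : i + seq_len]       # model inputs
#   target = self.data[i + 1 : i + 1 + seq_len]   # next-token targets
# both are transposed to (batch_size, seq_len) and moved to self.device.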
......@@ -315,7 +332,7 @@ class LMOrderedIterator(object):
max_len = self.bptt + max_deviation * std
i = start
while True:
bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0
bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
data, target, seq_len = self.get_batch(i, bptt)
i += seq_len
......@@ -328,7 +345,7 @@ class LMOrderedIterator(object):
class LMShuffledIterator(object):
def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
def __init__(self, data, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
"""
data -- list[LongTensor] -- there is no order among the LongTensors
"""
......@@ -343,8 +360,7 @@ class LMShuffledIterator(object):
def get_sent_stream(self):
# index iterator
epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
else np.array(range(len(self.data)))
epoch_indices = np.random.permutation(len(self.data)) if self.shuffle else np.array(range(len(self.data)))
# sentence iterator
for idx in epoch_indices:
......@@ -376,10 +392,8 @@ class LMShuffledIterator(object):
# number of new tokens to fill in
n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
# first n_retain tokens are retained from last batch
data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
streams[i][:n_new]
target[n_filled:n_filled+n_new, i] = \
streams[i][1:n_new+1]
data[n_retain + n_filled : n_retain + n_filled + n_new, i] = streams[i][:n_new]
target[n_filled : n_filled + n_new, i] = streams[i][1 : n_new + 1]
streams[i] = streams[i][n_new:]
n_filled += n_new
except StopIteration:
......@@ -408,8 +422,7 @@ class LMShuffledIterator(object):
class LMMultiFileIterator(LMShuffledIterator):
def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
shuffle=False):
def __init__(self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
self.paths = paths
self.vocab = vocab
......@@ -460,15 +473,16 @@ class TransfoXLCorpus(object):
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
", ".join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
corpus_file))
corpus_file,
)
)
return None
if resolved_corpus_file == corpus_file:
logger.info("loading corpus file {}".format(corpus_file))
else:
logger.info("loading corpus file {} from cache at {}".format(
corpus_file, resolved_corpus_file))
logger.info("loading corpus file {} from cache at {}".format(corpus_file, resolved_corpus_file))
# Instantiate tokenizer.
corpus = cls(*inputs, **kwargs)
......@@ -494,83 +508,78 @@ class TransfoXLCorpus(object):
def build_corpus(self, path, dataset):
self.dataset = dataset
if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
self.vocab.count_file(os.path.join(path, 'train.txt'))
self.vocab.count_file(os.path.join(path, 'valid.txt'))
self.vocab.count_file(os.path.join(path, 'test.txt'))
elif self.dataset == 'wt103':
self.vocab.count_file(os.path.join(path, 'train.txt'))
elif self.dataset == 'lm1b':
if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
self.vocab.count_file(os.path.join(path, "train.txt"))
self.vocab.count_file(os.path.join(path, "valid.txt"))
self.vocab.count_file(os.path.join(path, "test.txt"))
elif self.dataset == "wt103":
self.vocab.count_file(os.path.join(path, "train.txt"))
elif self.dataset == "lm1b":
train_path_pattern = os.path.join(
path, '1-billion-word-language-modeling-benchmark-r13output',
'training-monolingual.tokenized.shuffled', 'news.en-*')
path,
"1-billion-word-language-modeling-benchmark-r13output",
"training-monolingual.tokenized.shuffled",
"news.en-*",
)
train_paths = glob.glob(train_path_pattern)
# the vocab will load from file when build_vocab() is called
self.vocab.build_vocab()
if self.dataset in ['ptb', 'wt2', 'wt103']:
self.train = self.vocab.encode_file(
os.path.join(path, 'train.txt'), ordered=True)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True)
elif self.dataset in ['enwik8', 'text8']:
self.train = self.vocab.encode_file(
os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
elif self.dataset == 'lm1b':
if self.dataset in ["ptb", "wt2", "wt103"]:
self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True)
self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True)
self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True)
elif self.dataset in ["enwik8", "text8"]:
self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True, add_eos=False)
self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True, add_eos=False)
elif self.dataset == "lm1b":
self.train = train_paths
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True)
self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=False, add_double_eos=True)
def get_iterator(self, split, *args, **kwargs):
if split == 'train':
if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
if split == "train":
if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
data_iter = LMOrderedIterator(self.train, *args, **kwargs)
elif self.dataset == 'lm1b':
kwargs['shuffle'] = True
elif self.dataset == "lm1b":
kwargs["shuffle"] = True
data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
elif split in ['valid', 'test']:
data = self.valid if split == 'valid' else self.test
if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
elif split in ["valid", "test"]:
data = self.valid if split == "valid" else self.test
if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
data_iter = LMOrderedIterator(data, *args, **kwargs)
elif self.dataset == 'lm1b':
elif self.dataset == "lm1b":
data_iter = LMShuffledIterator(data, *args, **kwargs)
return data_iter
def get_lm_corpus(datadir, dataset):
fn = os.path.join(datadir, 'cache.pt')
fn_pickle = os.path.join(datadir, 'cache.pkl')
fn = os.path.join(datadir, "cache.pt")
fn_pickle = os.path.join(datadir, "cache.pkl")
if os.path.exists(fn):
logger.info('Loading cached dataset...')
logger.info("Loading cached dataset...")
corpus = torch.load(fn_pickle)
elif os.path.exists(fn):
logger.info('Loading cached dataset from pickle...')
logger.info("Loading cached dataset from pickle...")
with open(fn, "rb") as fp:
corpus = pickle.load(fp)
else:
logger.info('Producing dataset {}...'.format(dataset))
logger.info("Producing dataset {}...".format(dataset))
kwargs = {}
if dataset in ['wt103', 'wt2']:
kwargs['special'] = ['<eos>']
kwargs['lower_case'] = False
elif dataset == 'ptb':
kwargs['special'] = ['<eos>']
kwargs['lower_case'] = True
elif dataset == 'lm1b':
kwargs['special'] = []
kwargs['lower_case'] = False
kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
elif dataset in ['enwik8', 'text8']:
if dataset in ["wt103", "wt2"]:
kwargs["special"] = ["<eos>"]
kwargs["lower_case"] = False
elif dataset == "ptb":
kwargs["special"] = ["<eos>"]
kwargs["lower_case"] = True
elif dataset == "lm1b":
kwargs["special"] = []
kwargs["lower_case"] = False
kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt")
elif dataset in ["enwik8", "text8"]:
pass
corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
......
......@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
......@@ -34,9 +33,10 @@ if is_torch_available():
logger = logging.getLogger(__name__)
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'
TOKENIZER_CONFIG_FILE = 'tokenizer_config.json'
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
class PreTrainedTokenizer(object):
""" Base class for all tokenizers.
......@@ -69,14 +69,22 @@ class PreTrainedTokenizer(object):
        - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensures they won't be split by the tokenization process. Will be associated with ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
pretrained_init_configuration = {}
max_model_input_sizes = {}
SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
"pad_token", "cls_token", "mask_token",
"additional_special_tokens"]
SPECIAL_TOKENS_ATTRIBUTES = [
"bos_token",
"eos_token",
"unk_token",
"sep_token",
"pad_token",
"cls_token",
"mask_token",
"additional_special_tokens",
]
padding_side = "right"
......@@ -227,7 +235,7 @@ class PreTrainedTokenizer(object):
self.max_len = max_len if max_len is not None else int(1e12)
        # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
self.padding_side = kwargs.pop('padding_side', self.padding_side)
self.padding_side = kwargs.pop("padding_side", self.padding_side)
# Added tokens
self.added_tokens_encoder = {}
......@@ -240,13 +248,14 @@ class PreTrainedTokenizer(object):
for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == 'additional_special_tokens':
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(
isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value
)
else:
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
setattr(self, key, value)
@classmethod
def from_pretrained(cls, *inputs, **kwargs):
r"""
......@@ -302,13 +311,12 @@ class PreTrainedTokenizer(object):
"""
return cls._from_pretrained(*inputs, **kwargs)
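A minimal usage sketch; the shortcut name comes from the subclass' pretrained maps, while the cache directory and local path below are purely illustrative:

from transformers import BertTokenizer

# download/cache the vocabulary files for a shortcut name
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", cache_dir="/tmp/hf_cache")
# or point at a local directory previously written by save_pretrained()
tokenizer = BertTokenizer.from_pretrained("./my_tokenizer_dir")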
@classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
......@@ -317,15 +325,19 @@ class PreTrainedTokenizer(object):
# Get the vocabulary from AWS S3 bucket
for file_id, map_list in cls.pretrained_vocab_files_map.items():
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
if (
cls.pretrained_init_configuration
and pretrained_model_name_or_path in cls.pretrained_init_configuration
):
init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
else:
# Get the vocabulary from local files
logger.info(
"Model name '{}' not found in model shortcut name list ({}). "
"Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path))
pretrained_model_name_or_path, ", ".join(s3_models), pretrained_model_name_or_path
)
)
# Look for the tokenizer main vocabulary files
for file_id, file_name in cls.vocab_files_names.items():
......@@ -344,9 +356,10 @@ class PreTrainedTokenizer(object):
vocab_files[file_id] = full_file_name
# Look for the additional tokens files
additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE,
'tokenizer_config_file': TOKENIZER_CONFIG_FILE,
additional_files_names = {
"added_tokens_file": ADDED_TOKENS_FILE,
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
}
# If a path to a file was provided, get the parent directory
......@@ -366,9 +379,12 @@ class PreTrainedTokenizer(object):
"Model name '{}' was not found in tokenizers model name list ({}). "
"We assumed '{}' was a path or url to a directory containing vocabulary files "
"named {} but couldn't find such vocabulary files at this path or url.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values())))
", ".join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values()),
)
)
# Get files from url, cache, or disk depending on the case
try:
......@@ -377,17 +393,27 @@ class PreTrainedTokenizer(object):
if file_path is None:
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download)
resolved_vocab_files[file_id] = cached_path(
file_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
)
except EnvironmentError:
if pretrained_model_name_or_path in s3_models:
msg = "Couldn't reach server at '{}' to download vocabulary files."
else:
msg = "Model name '{}' was not found in tokenizers model name list ({}). " \
"We assumed '{}' was a path or url to a directory containing vocabulary files " \
msg = (
"Model name '{}' was not found in tokenizers model name list ({}). "
"We assumed '{}' was a path or url to a directory containing vocabulary files "
"named {}, but couldn't find such vocabulary files at this path or url.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values()))
", ".join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values()),
)
)
raise EnvironmentError(msg)
......@@ -395,16 +421,15 @@ class PreTrainedTokenizer(object):
if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path))
else:
logger.info("loading file {} from cache at {}".format(
file_path, resolved_vocab_files[file_id]))
logger.info("loading file {} from cache at {}".format(file_path, resolved_vocab_files[file_id]))
# Prepare tokenizer initialization kwargs
        # Did we save some inputs and kwargs to reload?
tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
if tokenizer_config_file is not None:
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
init_kwargs = json.load(tokenizer_config_handle)
saved_init_inputs = init_kwargs.pop('init_inputs', ())
saved_init_inputs = init_kwargs.pop("init_inputs", ())
if not init_inputs:
init_inputs = saved_init_inputs
else:
......@@ -419,11 +444,11 @@ class PreTrainedTokenizer(object):
# wont index sequences longer than the number of positional embeddings
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
if max_len is not None and isinstance(max_len, (int, float)):
init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len)
init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
# Merge resolved_vocab_files arguments in init_kwargs.
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
for args_name, file_path in resolved_vocab_files.items():
if args_name not in init_kwargs:
init_kwargs[args_name] = file_path
......@@ -438,8 +463,10 @@ class PreTrainedTokenizer(object):
try:
tokenizer = cls(*init_inputs, **init_kwargs)
except OSError:
OSError("Unable to load vocabulary from file. "
"Please check that the provided vocabulary is accessible and not corrupted.")
OSError(
"Unable to load vocabulary from file. "
"Please check that the provided vocabulary is accessible and not corrupted."
)
# Save inputs and kwargs for saving and re-loading with ``save_pretrained``
tokenizer.init_inputs = init_inputs
......@@ -449,13 +476,12 @@ class PreTrainedTokenizer(object):
if added_tokens_file is not None:
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
added_tok_encoder = json.load(added_tokens_handle)
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
tokenizer.added_tokens_encoder.update(added_tok_encoder)
tokenizer.added_tokens_decoder.update(added_tok_decoder)
return tokenizer
def save_pretrained(self, save_directory):
""" Save the tokenizer vocabulary files together with:
- added tokens,
......@@ -476,28 +502,27 @@ class PreTrainedTokenizer(object):
tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
tokenizer_config = copy.deepcopy(self.init_kwargs)
tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
for file_id in self.vocab_files_names.keys():
tokenizer_config.pop(file_id, None)
with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
with open(added_tokens_file, 'w', encoding='utf-8') as f:
with open(added_tokens_file, "w", encoding="utf-8") as f:
if self.added_tokens_encoder:
out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
else:
out_str = u"{}"
out_str = "{}"
f.write(out_str)
vocab_files = self.save_vocabulary(save_directory)
return vocab_files + (special_tokens_map_file, added_tokens_file)
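A hedged round-trip sketch (the output directory name is illustrative): the files written above are exactly what ``from_pretrained`` looks for when given a local path.

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.save_pretrained("./gpt2_tokenizer")
# ./gpt2_tokenizer now contains vocab.json, merges.txt, special_tokens_map.json,
# added_tokens.json and tokenizer_config.json
reloaded = GPT2Tokenizer.from_pretrained("./gpt2_tokenizer")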
def save_vocabulary(self, save_directory):
""" Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
and special token mappings.
......@@ -506,17 +531,14 @@ class PreTrainedTokenizer(object):
"""
raise NotImplementedError
def vocab_size(self):
""" Size of the base vocabulary (without the added tokens) """
raise NotImplementedError
def __len__(self):
""" Size of the full vocabulary with the added tokens """
return self.vocab_size + len(self.added_tokens_encoder)
def add_tokens(self, new_tokens):
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
......@@ -544,16 +566,18 @@ class PreTrainedTokenizer(object):
to_add_tokens = []
for token in new_tokens:
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens:
if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
token = token.lower()
if token != self.unk_token and \
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
token not in to_add_tokens:
if (
token != self.unk_token
and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
and token not in to_add_tokens
):
to_add_tokens.append(token)
logger.info("Adding %s to the vocabulary", token)
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
self.added_tokens_decoder.update(added_tok_decoder)
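Typical usage, mirroring the pattern from the library documentation: after adding tokens, the model's embedding matrix has to be resized to the new vocabulary size.

from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

num_added = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
model.resize_token_embeddings(len(tokenizer))  # len() includes the added tokens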
......@@ -622,8 +646,10 @@ class PreTrainedTokenizer(object):
added_tokens = 0
for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
if key == 'additional_special_tokens':
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
if key == "additional_special_tokens":
assert isinstance(value, (list, tuple)) and all(
isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value
)
added_tokens += self.add_tokens(value)
else:
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
......@@ -633,7 +659,6 @@ class PreTrainedTokenizer(object):
return added_tokens
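# Illustrative sketch (not from this commit) of add_tokens / add_special_tokens in use.
# The BERT checkpoint and the token strings are made-up assumptions.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["new_tok1", "my_new-tok2"])                         # plain vocabulary additions
tokenizer.add_special_tokens({"additional_special_tokens": ["<CTX>"]})    # routed through add_tokens above
print(len(tokenizer))  # __len__ = vocab_size + number of added tokens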
def tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer.
Split in words for word-based vocabulary or sub-words for sub-word-based
......@@ -649,14 +674,10 @@ class PreTrainedTokenizer(object):
def lowercase_text(t):
# convert non-special tokens to lowercase
escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
pattern = r'(' + r'|'.join(escaped_special_toks) + r')|' + \
r'(.+?)'
return re.sub(
pattern,
lambda m: m.groups()[0] or m.groups()[1].lower(),
t)
if self.init_kwargs.get('do_lower_case', False):
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
return re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), t)
if self.init_kwargs.get("do_lower_case", False):
text = lowercase_text(text)
def split_on_token(tok, text):
......@@ -694,9 +715,14 @@ class PreTrainedTokenizer(object):
tokenized_text += [sub_text]
text_list = tokenized_text
return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) \
if token not in self.unique_added_tokens_encoder
else [token] for token in tokenized_text)))
return list(
itertools.chain.from_iterable(
(
self._tokenize(token, **kwargs) if token not in self.unique_added_tokens_encoder else [token]
for token in tokenized_text
)
)
)
added_tokens = self.unique_added_tokens_encoder
tokenized_text = split_on_tokens(added_tokens, text)
......@@ -737,16 +763,18 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token):
raise NotImplementedError
def encode(self,
def encode(
self,
text,
text_pair=None,
add_special_tokens=True,
max_length=None,
stride=0,
truncation_strategy='longest_first',
truncation_strategy="longest_first",
pad_to_max_length=False,
return_tensors=None,
**kwargs):
**kwargs
):
"""
Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
......@@ -781,7 +809,8 @@ class PreTrainedTokenizer(object):
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
"""
encoded_inputs = self.encode_plus(text,
encoded_inputs = self.encode_plus(
text,
text_pair=text_pair,
max_length=max_length,
add_special_tokens=add_special_tokens,
......@@ -789,24 +818,27 @@ class PreTrainedTokenizer(object):
truncation_strategy=truncation_strategy,
pad_to_max_length=pad_to_max_length,
return_tensors=return_tensors,
**kwargs)
**kwargs
)
return encoded_inputs["input_ids"]
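# Illustrative sketch (not from this commit): encode simply forwards to encode_plus and keeps
# only encoded_inputs["input_ids"]. The checkpoint name is an assumption.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer.encode("Hello world", add_special_tokens=True, max_length=16, pad_to_max_length=True)
print(ids)  # a plain Python list of token ids, padded on the right up to 16 entries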
def encode_plus(self,
def encode_plus(
self,
text,
text_pair=None,
add_special_tokens=True,
max_length=None,
stride=0,
truncation_strategy='longest_first',
truncation_strategy="longest_first",
pad_to_max_length=False,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
**kwargs):
**kwargs
):
"""
Returns a dictionary containing the encoded sequence or sequence pair and additional information:
the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
......@@ -874,12 +906,15 @@ class PreTrainedTokenizer(object):
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
return text
else:
raise ValueError("Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.")
raise ValueError(
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
)
first_ids = get_input_ids(text)
second_ids = get_input_ids(text_pair) if text_pair is not None else None
return self.prepare_for_model(first_ids,
return self.prepare_for_model(
first_ids,
pair_ids=second_ids,
max_length=max_length,
pad_to_max_length=pad_to_max_length,
......@@ -890,18 +925,21 @@ class PreTrainedTokenizer(object):
return_attention_mask=return_attention_mask,
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask)
return_special_tokens_mask=return_special_tokens_mask,
)
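# Illustrative sketch (not from this commit): encode_plus on a sentence pair with the default
# return_* flags. The checkpoint name is an assumption.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
enc = tokenizer.encode_plus("How are you?", "I am fine.", add_special_tokens=True, max_length=32)
print(sorted(enc))  # ['attention_mask', 'input_ids', 'token_type_ids'] with the defaults above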
def batch_encode_plus(self,
def batch_encode_plus(
self,
batch_text_or_text_pairs=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncation_strategy='longest_first',
truncation_strategy="longest_first",
return_tensors=None,
return_input_lengths=False,
return_attention_masks=False,
**kwargs):
**kwargs
):
"""
Returns a dictionary containing the encoded sequence or sequence pair and additional information:
the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
......@@ -933,12 +971,19 @@ class PreTrainedTokenizer(object):
ids, pair_ids = ids_or_pair_ids
else:
ids, pair_ids = ids_or_pair_ids, None
outputs = self.encode_plus(ids, pair_ids, add_special_tokens=add_special_tokens, max_length=max_length,
stride=stride, truncation_strategy=truncation_strategy, return_tensors=None)
outputs = self.encode_plus(
ids,
pair_ids,
add_special_tokens=add_special_tokens,
max_length=max_length,
stride=stride,
truncation_strategy=truncation_strategy,
return_tensors=None,
)
# Append the non-padded length to the output
if return_input_lengths:
outputs['input_len'] = len(outputs['input_ids'])
outputs["input_len"] = len(outputs["input_ids"])
for key, value in outputs.items():
if key not in batch_outputs:
......@@ -946,11 +991,11 @@ class PreTrainedTokenizer(object):
batch_outputs[key].append(value)
# Compute longest sequence size
max_seq_len = max(map(len, batch_outputs['input_ids']))
max_seq_len = max(map(len, batch_outputs["input_ids"]))
if return_attention_masks:
# Allow the model to not give any special attention to padded input
batch_outputs['attention_mask'] = [[0] * len(v) for v in batch_outputs['input_ids']]
batch_outputs["attention_mask"] = [[0] * len(v) for v in batch_outputs["input_ids"]]
if return_tensors is not None:
......@@ -958,34 +1003,48 @@ class PreTrainedTokenizer(object):
for key, value in batch_outputs.items():
padded_value = value
if key != 'input_len':
if key != "input_len":
# Padding handle
padded_value = [v + [self.pad_token_id if key == 'input_ids' else 1] * (max_seq_len - len(v)) for v in padded_value]
padded_value = [
v + [self.pad_token_id if key == "input_ids" else 1] * (max_seq_len - len(v))
for v in padded_value
]
if return_tensors == 'tf' and is_tf_available():
if return_tensors == "tf" and is_tf_available():
batch_outputs[key] = tf.constant(padded_value)
elif return_tensors == 'pt' and is_torch_available():
elif return_tensors == "pt" and is_torch_available():
batch_outputs[key] = torch.tensor(padded_value)
elif return_tensors is not None:
logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors
)
)
# encoder_attention_mask requires 1 for real token, 0 for padding, just invert value
if return_attention_masks:
if is_tf_available():
batch_outputs['attention_mask'] = tf.abs(batch_outputs['attention_mask'] - 1)
batch_outputs["attention_mask"] = tf.abs(batch_outputs["attention_mask"] - 1)
else:
batch_outputs['attention_mask'] = torch.abs(batch_outputs['attention_mask'] - 1)
batch_outputs["attention_mask"] = torch.abs(batch_outputs["attention_mask"] - 1)
return batch_outputs
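# Illustrative sketch (not from this commit): batching plain strings and text pairs together.
# Assumes PyTorch is installed so the return_tensors="pt" branch above is taken; the checkpoint
# name is an assumption.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer.batch_encode_plus(
    ["a short sentence", ("a question", "and a much longer answer that forces padding")],
    add_special_tokens=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # (2, longest sequence length); shorter rows end in pad_token_id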
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
truncation_strategy='longest_first',
def prepare_for_model(
self,
ids,
pair_ids=None,
max_length=None,
add_special_tokens=True,
stride=0,
truncation_strategy="longest_first",
pad_to_max_length=False,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
return_special_tokens_mask=False,
):
"""
Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model.
It adds special tokens, truncates
......@@ -1050,10 +1109,13 @@ class PreTrainedTokenizer(object):
# Handle max sequence length
total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
if max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
num_tokens_to_remove=total_len-max_length,
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
ids,
pair_ids=pair_ids,
num_tokens_to_remove=total_len - max_length,
truncation_strategy=truncation_strategy,
stride=stride)
stride=stride,
)
if return_overflowing_tokens:
encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
......@@ -1081,35 +1143,45 @@ class PreTrainedTokenizer(object):
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len))
"indexing errors".format(len(ids), self.max_len)
)
needs_to_be_padded = pad_to_max_length and (
max_length and len(encoded_inputs["input_ids"]) < max_length
or
max_length is None and len(encoded_inputs["input_ids"]) < self.max_len and self.max_len <= 10000
max_length
and len(encoded_inputs["input_ids"]) < max_length
or max_length is None
and len(encoded_inputs["input_ids"]) < self.max_len
and self.max_len <= 10000
)
if pad_to_max_length and max_length is None and self.max_len > 10000:
logger.warning("Sequence can't be padded as no maximum length is specified and the model maximum length is too high.")
logger.warning(
"Sequence can't be padded as no maximum length is specified and the model maximum length is too high."
)
if needs_to_be_padded:
difference = (max_length if max_length is not None else self.max_len) - len(encoded_inputs["input_ids"])
if self.padding_side == 'right':
if self.padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
)
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
elif self.padding_side == 'left':
elif self.padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
if return_token_type_ids:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs["token_type_ids"]
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"
]
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
......@@ -1121,14 +1193,14 @@ class PreTrainedTokenizer(object):
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
# Prepare inputs as tensors if asked
if return_tensors == 'tf' and is_tf_available():
if return_tensors == "tf" and is_tf_available():
encoded_inputs["input_ids"] = tf.constant([encoded_inputs["input_ids"]])
encoded_inputs["token_type_ids"] = tf.constant([encoded_inputs["token_type_ids"]])
if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = tf.constant([encoded_inputs["attention_mask"]])
elif return_tensors == 'pt' and is_torch_available():
elif return_tensors == "pt" and is_torch_available():
encoded_inputs["input_ids"] = torch.tensor([encoded_inputs["input_ids"]])
encoded_inputs["token_type_ids"] = torch.tensor([encoded_inputs["token_type_ids"]])
......@@ -1137,11 +1209,15 @@ class PreTrainedTokenizer(object):
elif return_tensors is not None:
logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors))
return_tensors
)
)
return encoded_inputs
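# Dependency-free sketch (not from this commit) of the left/right padding branch above, with
# made-up values. XLNet further down sets padding_side = "left"; most tokenizers pad on the right.
input_ids, pad_token_id, difference = [5, 6, 7], 0, 2

padding_side = "left"
if padding_side == "right":
    attention_mask = [1] * len(input_ids) + [0] * difference
    input_ids = input_ids + [pad_token_id] * difference
else:
    attention_mask = [0] * difference + [1] * len(input_ids)
    input_ids = [pad_token_id] * difference + input_ids

print(input_ids, attention_mask)  # [0, 0, 5, 6, 7] [0, 0, 1, 1, 1]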
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
def truncate_sequences(
self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy="longest_first", stride=0
):
"""Truncates a sequence pair in place to the maximum length.
truncation_strategy: string selected in the following options:
- 'longest_first' (default): iteratively reduce the input sequences until the input is under max_length
......@@ -1154,7 +1230,7 @@ class PreTrainedTokenizer(object):
if num_tokens_to_remove <= 0:
return ids, pair_ids, []
if truncation_strategy == 'longest_first':
if truncation_strategy == "longest_first":
overflowing_tokens = []
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
......@@ -1165,20 +1241,22 @@ class PreTrainedTokenizer(object):
window_len = min(len(ids), stride)
if window_len > 0:
overflowing_tokens = ids[-window_len:] + overflowing_tokens
elif truncation_strategy == 'only_first':
elif truncation_strategy == "only_first":
assert len(ids) > num_tokens_to_remove
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove]
elif truncation_strategy == 'only_second':
elif truncation_strategy == "only_second":
assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove]
elif truncation_strategy == 'do_not_truncate':
elif truncation_strategy == "do_not_truncate":
raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")
else:
raise ValueError("Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']")
raise ValueError(
"Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']"
)
return (ids, pair_ids, overflowing_tokens)
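# Standalone sketch (not from this commit) of the 'longest_first' strategy with made-up ids;
# the stride/overflowing-token bookkeeping from the code above is omitted.
ids, pair_ids, num_tokens_to_remove = [1, 2, 3, 4, 5], [10, 11, 12], 3

for _ in range(num_tokens_to_remove):
    if pair_ids is None or len(ids) > len(pair_ids):
        ids = ids[:-1]            # trim the currently longer sequence
    else:
        pair_ids = pair_ids[:-1]

print(ids, pair_ids)  # [1, 2, 3] [10, 11]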
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
......@@ -1246,7 +1324,7 @@ class PreTrainedTokenizer(object):
The simplest way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
but we often want to remove sub-word tokenization artifacts at the same time.
"""
return ' '.join(self.convert_ids_to_tokens(tokens))
return " ".join(self.convert_ids_to_tokens(tokens))
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
"""
......@@ -1278,7 +1356,7 @@ class PreTrainedTokenizer(object):
current_sub_text.append(token)
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
text = ' '.join(sub_texts)
text = " ".join(sub_texts)
if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
......@@ -1323,7 +1401,17 @@ class PreTrainedTokenizer(object):
def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
"""
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" do not", " don't")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
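# Example call (the sample string is made up):
from transformers import PreTrainedTokenizer

print(PreTrainedTokenizer.clean_up_tokenization("do n't stop , it 's fine ."))
# -> don't stop, it's fine.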
......@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for XLM."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
......@@ -32,71 +31,71 @@ from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-vocab.json",
"vocab_file": {
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
"xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json",
"xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json",
"xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json",
"xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json",
"xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
"xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
"xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
"xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
"xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-vocab.json",
},
'merges_file':
{
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-merges.txt",
"merges_file": {
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
"xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
"xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
"xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt",
"xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt",
"xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
"xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
"xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
"xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
"xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-mlm-en-2048': 512,
'xlm-mlm-ende-1024': 512,
'xlm-mlm-enfr-1024': 512,
'xlm-mlm-enro-1024': 512,
'xlm-mlm-tlm-xnli15-1024': 512,
'xlm-mlm-xnli15-1024': 512,
'xlm-clm-enfr-1024': 512,
'xlm-clm-ende-1024': 512,
'xlm-mlm-17-1280': 512,
'xlm-mlm-100-1280': 512,
"xlm-mlm-en-2048": 512,
"xlm-mlm-ende-1024": 512,
"xlm-mlm-enfr-1024": 512,
"xlm-mlm-enro-1024": 512,
"xlm-mlm-tlm-xnli15-1024": 512,
"xlm-mlm-xnli15-1024": 512,
"xlm-clm-enfr-1024": 512,
"xlm-clm-ende-1024": 512,
"xlm-mlm-17-1280": 512,
"xlm-mlm-100-1280": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True},
'xlm-mlm-ende-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "de",
"1": "en"},
"lang2id": { "de": 0,
"en": 1 }},
'xlm-mlm-enfr-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "en",
"1": "fr"},
"lang2id": { "en": 0,
"fr": 1 }},
'xlm-mlm-enro-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "en",
"1": "ro"},
"lang2id": { "en": 0,
"ro": 1 }},
'xlm-mlm-tlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "ar",
"xlm-mlm-en-2048": {"do_lowercase_and_remove_accent": True},
"xlm-mlm-ende-1024": {
"do_lowercase_and_remove_accent": True,
"id2lang": {"0": "de", "1": "en"},
"lang2id": {"de": 0, "en": 1},
},
"xlm-mlm-enfr-1024": {
"do_lowercase_and_remove_accent": True,
"id2lang": {"0": "en", "1": "fr"},
"lang2id": {"en": 0, "fr": 1},
},
"xlm-mlm-enro-1024": {
"do_lowercase_and_remove_accent": True,
"id2lang": {"0": "en", "1": "ro"},
"lang2id": {"en": 0, "ro": 1},
},
"xlm-mlm-tlm-xnli15-1024": {
"do_lowercase_and_remove_accent": True,
"id2lang": {
"0": "ar",
"1": "bg",
"2": "de",
"3": "el",
......@@ -110,8 +109,10 @@ PRETRAINED_INIT_CONFIGURATION = {
"11": "tr",
"12": "ur",
"13": "vi",
"14": "zh"},
"lang2id": { "ar": 0,
"14": "zh",
},
"lang2id": {
"ar": 0,
"bg": 1,
"de": 2,
"el": 3,
......@@ -125,9 +126,13 @@ PRETRAINED_INIT_CONFIGURATION = {
"tr": 11,
"ur": 12,
"vi": 13,
"zh": 14 }},
'xlm-mlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "ar",
"zh": 14,
},
},
"xlm-mlm-xnli15-1024": {
"do_lowercase_and_remove_accent": True,
"id2lang": {
"0": "ar",
"1": "bg",
"2": "de",
"3": "el",
......@@ -141,8 +146,10 @@ PRETRAINED_INIT_CONFIGURATION = {
"11": "tr",
"12": "ur",
"13": "vi",
"14": "zh"},
"lang2id": { "ar": 0,
"14": "zh",
},
"lang2id": {
"ar": 0,
"bg": 1,
"de": 2,
"el": 3,
......@@ -156,18 +163,21 @@ PRETRAINED_INIT_CONFIGURATION = {
"tr": 11,
"ur": 12,
"vi": 13,
"zh": 14 }},
'xlm-clm-enfr-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "en",
"1": "fr"},
"lang2id": { "en": 0,
"fr": 1 }},
'xlm-clm-ende-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "de",
"1": "en"},
"lang2id": { "de": 0,
"en": 1 }},
'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False,
"zh": 14,
},
},
"xlm-clm-enfr-1024": {
"do_lowercase_and_remove_accent": True,
"id2lang": {"0": "en", "1": "fr"},
"lang2id": {"en": 0, "fr": 1},
},
"xlm-clm-ende-1024": {
"do_lowercase_and_remove_accent": True,
"id2lang": {"0": "de", "1": "en"},
"lang2id": {"de": 0, "en": 1},
},
"xlm-mlm-17-1280": {
"do_lowercase_and_remove_accent": False,
"id2lang": {
"0": "ar",
"1": "de",
......@@ -185,7 +195,7 @@ PRETRAINED_INIT_CONFIGURATION = {
"13": "sv",
"14": "tr",
"15": "vi",
"16": "zh"
"16": "zh",
},
"lang2id": {
"ar": 0,
......@@ -204,8 +214,11 @@ PRETRAINED_INIT_CONFIGURATION = {
"sv": 13,
"tr": 14,
"vi": 15,
"zh": 16}},
'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False,
"zh": 16,
},
},
"xlm-mlm-100-1280": {
"do_lowercase_and_remove_accent": False,
"id2lang": {
"0": "af",
"1": "als",
......@@ -306,7 +319,7 @@ PRETRAINED_INIT_CONFIGURATION = {
"96": "zh",
"97": "zh_classical",
"98": "zh_min_nan",
"99": "zh_yue"
"99": "zh_yue",
},
"lang2id": {
"af": 0,
......@@ -408,10 +421,12 @@ PRETRAINED_INIT_CONFIGURATION = {
"zh": 96,
"zh_classical": 97,
"zh_min_nan": 98,
"zh_yue": 99
}},
"zh_yue": 99,
},
},
}
def get_pairs(word):
"""
Return set of symbol pairs in a word.
......@@ -430,7 +445,7 @@ def lowercase_and_remove_accent(text):
Lowercases and strips accents from a piece of text, based on
https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
"""
text = ' '.join(text)
text = " ".join(text)
text = text.lower()
text = unicodedata.normalize("NFD", text)
output = []
......@@ -439,73 +454,73 @@ def lowercase_and_remove_accent(text):
if cat == "Mn":
continue
output.append(char)
return "".join(output).lower().split(' ')
return "".join(output).lower().split(" ")
def replace_unicode_punct(text):
'''
"""
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
'''
text = text.replace(',', ',')
text = re.sub(r'。\s*', '. ', text)
text = text.replace('、', ',')
text = text.replace('”', '"')
text = text.replace('“', '"')
text = text.replace('∶', ':')
text = text.replace(':', ':')
text = text.replace('?', '?')
text = text.replace('《', '"')
text = text.replace('》', '"')
text = text.replace(')', ')')
text = text.replace('!', '!')
text = text.replace('(', '(')
text = text.replace(';', ';')
text = text.replace('1', '"')
text = text.replace('」', '"')
text = text.replace('「', '"')
text = text.replace('0', '0')
text = text.replace('3', '3')
text = text.replace('2', '2')
text = text.replace('5', '5')
text = text.replace('6', '6')
text = text.replace('9', '9')
text = text.replace('7', '7')
text = text.replace('8', '8')
text = text.replace('4', '4')
text = re.sub(r'.\s*', '. ', text)
text = text.replace('~', '~')
text = text.replace('’', '\'')
text = text.replace('…', '...')
text = text.replace('━', '-')
text = text.replace('〈', '<')
text = text.replace('〉', '>')
text = text.replace('【', '[')
text = text.replace('】', ']')
text = text.replace('%', '%')
"""
text = text.replace(",", ",")
text = re.sub(r"。\s*", ". ", text)
text = text.replace("、", ",")
text = text.replace("”", '"')
text = text.replace("“", '"')
text = text.replace("∶", ":")
text = text.replace(":", ":")
text = text.replace("?", "?")
text = text.replace("《", '"')
text = text.replace("》", '"')
text = text.replace(")", ")")
text = text.replace("!", "!")
text = text.replace("(", "(")
text = text.replace(";", ";")
text = text.replace("1", '"')
text = text.replace("」", '"')
text = text.replace("「", '"')
text = text.replace("0", "0")
text = text.replace("3", "3")
text = text.replace("2", "2")
text = text.replace("5", "5")
text = text.replace("6", "6")
text = text.replace("9", "9")
text = text.replace("7", "7")
text = text.replace("8", "8")
text = text.replace("4", "4")
text = re.sub(r".\s*", ". ", text)
text = text.replace("~", "~")
text = text.replace("’", "'")
text = text.replace("…", "...")
text = text.replace("━", "-")
text = text.replace("〈", "<")
text = text.replace("〉", ">")
text = text.replace("【", "[")
text = text.replace("】", "]")
text = text.replace("%", "%")
return text
def remove_non_printing_char(text):
'''
"""
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
'''
"""
output = []
for char in text:
cat = unicodedata.category(char)
if cat.startswith('C'):
if cat.startswith("C"):
continue
output.append(char)
return "".join(output)
def romanian_preprocessing(text):
'''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`'''
"""Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`"""
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
text = text.replace("\u0218", "S").replace("\u0219", "s") #s-comma
text = text.replace("\u021a", "T").replace("\u021b", "t") #t-comma
text = text.replace("\u0218", "S").replace("\u0219", "s") # s-comma
text = text.replace("\u021a", "T").replace("\u021b", "t") # t-comma
text = text.replace("\u0102", "A").replace("\u0103", "a")
text = text.replace("\u00C2", "A").replace("\u00E2", "a")
text = text.replace("\u00CE", "I").replace("\u00EE", "i")
......@@ -531,24 +546,49 @@ class XLMTokenizer(PreTrainedTokenizer):
- `do_lowercase_and_remove_accent` controls lowercasing and accent removal (automatically set for pretrained vocabularies)
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
sep_token="</s>", pad_token="<pad>", cls_token="</s>",
mask_token="<special1>", additional_special_tokens=["<special0>",
"<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
"<special6>", "<special7>", "<special8>", "<special9>"],
lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True,
**kwargs):
super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
sep_token=sep_token, pad_token=pad_token,
cls_token=cls_token, mask_token=mask_token,
def __init__(
self,
vocab_file,
merges_file,
unk_token="<unk>",
bos_token="<s>",
sep_token="</s>",
pad_token="<pad>",
cls_token="</s>",
mask_token="<special1>",
additional_special_tokens=[
"<special0>",
"<special1>",
"<special2>",
"<special3>",
"<special4>",
"<special5>",
"<special6>",
"<special7>",
"<special8>",
"<special9>",
],
lang2id=None,
id2lang=None,
do_lowercase_and_remove_accent=True,
**kwargs
):
super(XLMTokenizer, self).__init__(
unk_token=unk_token,
bos_token=bos_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
......@@ -557,7 +597,7 @@ class XLMTokenizer(PreTrainedTokenizer):
self.cache_moses_punct_normalizer = dict()
# cache of sm.MosesTokenizer instance
self.cache_moses_tokenizer = dict()
self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
self.lang_with_custom_tokenizer = set(["zh", "th", "ja"])
# True for current supported model (v1.2.0), False for XLM-17 & 100
self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
self.lang2id = lang2id
......@@ -570,9 +610,9 @@ class XLMTokenizer(PreTrainedTokenizer):
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v:k for k,v in self.encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle:
merges = merges_handle.read().split('\n')[:-1]
self.decoder = {v: k for k, v in self.encoder.items()}
with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split("\n")[:-1]
merges = [tuple(merge.split()[:2]) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
......@@ -603,9 +643,14 @@ class XLMTokenizer(PreTrainedTokenizer):
if self.ja_word_tokenizer is None:
try:
import Mykytea
self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
self.ja_word_tokenizer = Mykytea.Mykytea(
"-model %s/local/share/kytea/model.bin" % os.path.expanduser("~")
)
except (AttributeError, ImportError) as e:
logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps")
logger.error(
"Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps"
)
logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
logger.error("2. autoreconf -i")
logger.error("3. ./configure --prefix=$HOME/local")
......@@ -619,16 +664,16 @@ class XLMTokenizer(PreTrainedTokenizer):
return len(self.encoder)
def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',)
word = tuple(token[:-1]) + (token[-1] + "</w>",)
if token in self.cache:
return self.cache[token]
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
......@@ -643,8 +688,8 @@ class XLMTokenizer(PreTrainedTokenizer):
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
......@@ -655,13 +700,13 @@ class XLMTokenizer(PreTrainedTokenizer):
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
if word == '\n </w>':
word = '\n</w>'
word = " ".join(word)
if word == "\n </w>":
word = "\n</w>"
self.cache[token] = word
return word
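# Compact standalone sketch (not from this commit) of the greedy merge loop in bpe() above.
# The merge ranks are invented and the word rebuilding is simplified relative to the original.
bpe_ranks = {("l", "o"): 0, ("lo", "w"): 1, ("e", "r</w>"): 2}  # lower rank = merge earlier

def get_pairs(word):
    return {(word[i], word[i + 1]) for i in range(len(word) - 1)}

word = ("l", "o", "w", "e", "r</w>")  # "lower" with the end-of-word marker attached
pairs = get_pairs(word)
while pairs:
    bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))
    if bigram not in bpe_ranks:
        break
    first, second = bigram
    new_word, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
            new_word.append(first + second)
            i += 2
        else:
            new_word.append(word[i])
            i += 1
    word = tuple(new_word)
    pairs = get_pairs(word)

print(" ".join(word))  # low er</w>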
def _tokenize(self, text, lang='en', bypass_tokenizer=False):
def _tokenize(self, text, lang="en", bypass_tokenizer=False):
"""
Tokenize a string given a language code. For Chinese, Japanese and Thai, we use a language-specific tokenizer. Otherwise, we use Moses.
......@@ -697,45 +742,49 @@ class XLMTokenizer(PreTrainedTokenizer):
List of tokens.
"""
if lang and self.lang2id and lang not in self.lang2id:
logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.")
logger.error(
"Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model."
)
if bypass_tokenizer:
text = text.split()
elif lang not in self.lang_with_custom_tokenizer:
text = self.moses_pipeline(text, lang=lang)
# TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
if lang == 'ro':
if lang == "ro":
text = romanian_preprocessing(text)
text = self.moses_tokenize(text, lang=lang)
elif lang == 'th':
elif lang == "th":
text = self.moses_pipeline(text, lang=lang)
try:
if 'pythainlp' not in sys.modules:
if "pythainlp" not in sys.modules:
from pythainlp.tokenize import word_tokenize as th_word_tokenize
else:
th_word_tokenize = sys.modules['pythainlp'].word_tokenize
th_word_tokenize = sys.modules["pythainlp"].word_tokenize
except (AttributeError, ImportError) as e:
logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps")
logger.error(
"Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps"
)
logger.error("1. pip install pythainlp")
raise e
text = th_word_tokenize(text)
elif lang == 'zh':
elif lang == "zh":
try:
if 'jieba' not in sys.modules:
if "jieba" not in sys.modules:
import jieba
else:
jieba = sys.modules['jieba']
jieba = sys.modules["jieba"]
except (AttributeError, ImportError) as e:
logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
logger.error("1. pip install jieba")
raise e
text = ' '.join(jieba.cut(text))
text = " ".join(jieba.cut(text))
text = self.moses_pipeline(text, lang=lang)
text = text.split()
elif lang == 'ja':
elif lang == "ja":
text = self.moses_pipeline(text, lang=lang)
text = self.ja_tokenize(text)
else:
raise ValueError('It should not reach here')
raise ValueError("It should not reach here")
if self.do_lowercase_and_remove_accent and not bypass_tokenizer:
text = lowercase_and_remove_accent(text)
......@@ -743,7 +792,7 @@ class XLMTokenizer(PreTrainedTokenizer):
split_tokens = []
for token in text:
if token:
split_tokens.extend([t for t in self.bpe(token).split(' ')])
split_tokens.extend([t for t in self.bpe(token).split(" ")])
return split_tokens
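# End-to-end usage sketch (not from this commit); the checkpoint comes from the maps above, and
# downloading it plus the Moses dependency (sacremoses at the time) are assumptions.
from transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
tokens = tokenizer.tokenize("Hello, how are you?", lang="en")  # lang reaches _tokenize via **kwargs
ids = tokenizer.convert_tokens_to_ids(tokens)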
......@@ -757,7 +806,7 @@ class XLMTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = ''.join(tokens).replace('</w>', ' ').strip()
out_string = "".join(tokens).replace("</w>", " ").strip()
return out_string
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
......@@ -792,8 +841,10 @@ class XLMTokenizer(PreTrainedTokenizer):
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
......@@ -820,20 +871,22 @@ class XLMTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f:
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
logger.warning(
"Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return vocab_file, merge_file
......@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
""" Tokenization classes for XLM-RoBERTa model."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
......@@ -26,29 +25,29 @@ from .tokenization_xlnet import SPIECE_UNDERLINE
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'xlm-roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-sentencepiece.bpe.model",
'xlm-roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll02-dutch': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll02-spanish': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll03-english': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll03-german': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model",
"vocab_file": {
"xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-sentencepiece.bpe.model",
"xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-sentencepiece.bpe.model",
"xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model",
"xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model",
"xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model",
"xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-roberta-base': 512,
'xlm-roberta-large': 512,
'xlm-roberta-large-finetuned-conll02-dutch': 512,
'xlm-roberta-large-finetuned-conll02-spanish': 512,
'xlm-roberta-large-finetuned-conll03-english': 512,
'xlm-roberta-large-finetuned-conll03-german': 512,
"xlm-roberta-base": 512,
"xlm-roberta-large": 512,
"xlm-roberta-large-finetuned-conll02-dutch": 512,
"xlm-roberta-large-finetuned-conll02-spanish": 512,
"xlm-roberta-large-finetuned-conll03-english": 512,
"xlm-roberta-large-finetuned-conll03-german": 512,
}
class XLMRobertaTokenizer(PreTrainedTokenizer):
"""
Adapted from RobertaTokenizer and XLNetTokenizer
......@@ -56,17 +55,33 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
- requires `SentencePiece <https://github.com/google/sentencepiece>`_
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>",
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>',
**kwargs):
super(XLMRobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
def __init__(
self,
vocab_file,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs
):
super(XLMRobertaTokenizer, self).__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs)
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
self.sp_model = spm.SentencePieceProcessor()
......@@ -85,7 +100,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
self.fairseq_offset = 1
self.fairseq_tokens_to_ids['<mask>'] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
......@@ -119,8 +134,10 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is None:
......@@ -164,7 +181,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def save_vocabulary(self, save_directory):
......@@ -174,7 +191,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
......
......@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for XLNet model."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
......@@ -27,22 +26,21 @@ from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
"vocab_file": {
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
"xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlnet-base-cased': None,
'xlnet-large-cased': None,
"xlnet-base-cased": None,
"xlnet-large-cased": None,
}
SPIECE_UNDERLINE = u'▁'
SPIECE_UNDERLINE = "▁"
# Segments (not really needed)
SEG_ID_A = 0
......@@ -51,27 +49,46 @@ SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4
class XLNetTokenizer(PreTrainedTokenizer):
"""
SentencePiece based tokenizer. Peculiarities:
- requires `SentencePiece <https://github.com/google/sentencepiece>`_
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
padding_side = "left"
def __init__(self, vocab_file,
do_lower_case=False, remove_space=True, keep_accents=False,
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
additional_special_tokens=["<eop>", "<eod>"], **kwargs):
super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token,
unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, additional_special_tokens=
additional_special_tokens, **kwargs)
def __init__(
self,
vocab_file,
do_lower_case=False,
remove_space=True,
keep_accents=False,
bos_token="<s>",
eos_token="</s>",
unk_token="<unk>",
sep_token="<sep>",
pad_token="<pad>",
cls_token="<cls>",
mask_token="<mask>",
additional_special_tokens=["<eop>", "<eod>"],
**kwargs
):
super(XLNetTokenizer, self).__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
......@@ -80,8 +97,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece")
logger.warning(
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.do_lower_case = do_lower_case
self.remove_space = remove_space
......@@ -105,24 +124,26 @@ class XLNetTokenizer(PreTrainedTokenizer):
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece")
logger.warning(
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)
def preprocess_text(self, inputs):
if self.remove_space:
outputs = ' '.join(inputs.strip().split())
outputs = " ".join(inputs.strip().split())
else:
outputs = inputs
outputs = outputs.replace("``", '"').replace("''", '"')
if six.PY2 and isinstance(outputs, str):
outputs = outputs.decode('utf-8')
outputs = outputs.decode("utf-8")
if not self.keep_accents:
outputs = unicodedata.normalize('NFKD', outputs)
outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
outputs = unicodedata.normalize("NFKD", outputs)
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
if self.do_lower_case:
outputs = outputs.lower()
......@@ -135,7 +156,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
text = self.preprocess_text(text)
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
if six.PY2 and isinstance(text, unicode):
text = text.encode('utf-8')
text = text.encode("utf-8")
if not sample:
pieces = self.sp_model.EncodeAsPieces(text)
......@@ -143,9 +164,8 @@ class XLNetTokenizer(PreTrainedTokenizer):
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
new_pieces = []
for piece in pieces:
if len(piece) > 1 and piece[-1] == str(',') and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(
piece[:-1].replace(SPIECE_UNDERLINE, ''))
if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
......@@ -161,7 +181,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
ret_pieces = []
for piece in new_pieces:
if isinstance(piece, str):
piece = piece.decode('utf-8')
piece = piece.decode("utf-8")
ret_pieces.append(piece)
new_pieces = ret_pieces
......@@ -175,12 +195,12 @@ class XLNetTokenizer(PreTrainedTokenizer):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
token = self.sp_model.IdToPiece(index)
if six.PY2 and return_unicode and isinstance(token, str):
token = token.decode('utf-8')
token = token.decode("utf-8")
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
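# The piece-joining step in isolation (made-up pieces, not from this commit):
SPIECE_UNDERLINE = "▁"

pieces = ["▁Hello", "▁world", "!"]
print("".join(pieces).replace(SPIECE_UNDERLINE, " ").strip())  # Hello world!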
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
......@@ -215,8 +235,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
......@@ -247,7 +269,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
......
''' Script for downloading all GLUE data.
""" Script for downloading all GLUE data.
Original source: https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
Note: for legal reasons, we are unable to host MRPC.
......@@ -16,7 +16,7 @@ rm MSRParaphraseCorpus.msi
1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
'''
"""
import os
import sys
......@@ -27,20 +27,23 @@ import urllib.request
import zipfile
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
"MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
"QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
"STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
"MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
"SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
"QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
"RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
"WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
"diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
TASK2PATH = {
"CoLA": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4",
"SST": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8",
"MRPC": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc",
"QQP": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5",
"STS": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5",
"MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",
"SNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df",
"QNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601",
"RTE": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb",
"WNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf",
"diagnostic": "https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D",
}
MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"
def download_and_extract(task, data_dir):
print("Downloading and extracting %s..." % task)
......@@ -51,6 +54,7 @@ def download_and_extract(task, data_dir):
os.remove(data_file)
print("\tCompleted!")
def format_mrpc(data_dir, path_to_data):
print("Processing MRPC...")
mrpc_dir = os.path.join(data_dir, "MRPC")
......@@ -72,30 +76,32 @@ def format_mrpc(data_dir, path_to_data):
dev_ids = []
with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
for row in ids_fh:
dev_ids.append(row.strip().split('\t'))
dev_ids.append(row.strip().split("\t"))
with open(mrpc_train_file, encoding="utf8") as data_fh, \
open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
with open(mrpc_train_file, encoding="utf8") as data_fh, open(
os.path.join(mrpc_dir, "train.tsv"), "w", encoding="utf8"
) as train_fh, open(os.path.join(mrpc_dir, "dev.tsv"), "w", encoding="utf8") as dev_fh:
header = data_fh.readline()
train_fh.write(header)
dev_fh.write(header)
for row in data_fh:
label, id1, id2, s1, s2 = row.strip().split('\t')
label, id1, id2, s1, s2 = row.strip().split("\t")
if [id1, id2] in dev_ids:
dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
else:
train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
with open(mrpc_test_file, encoding="utf8") as data_fh, \
open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
with open(mrpc_test_file, encoding="utf8") as data_fh, open(
os.path.join(mrpc_dir, "test.tsv"), "w", encoding="utf8"
) as test_fh:
header = data_fh.readline()
test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
for idx, row in enumerate(data_fh):
label, id1, id2, s1, s2 = row.strip().split('\t')
label, id1, id2, s1, s2 = row.strip().split("\t")
test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
print("\tCompleted!")
def download_diagnostic(data_dir):
print("Downloading and extracting diagnostic...")
if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
......@@ -105,8 +111,9 @@ def download_diagnostic(data_dir):
print("\tCompleted!")
return
def get_tasks(task_names):
task_names = task_names.split(',')
task_names = task_names.split(",")
if "all" in task_names:
tasks = TASKS
else:
......@@ -116,13 +123,19 @@ def get_tasks(task_names):
tasks.append(task_name)
return tasks
def main(arguments):
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
type=str, default='all')
parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',
type=str, default='')
parser.add_argument("--data_dir", help="directory to save data to", type=str, default="glue_data")
parser.add_argument(
"--tasks", help="tasks to download data for as a comma separated string", type=str, default="all"
)
parser.add_argument(
"--path_to_mrpc",
help="path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt",
type=str,
default="",
)
args = parser.parse_args(arguments)
if not os.path.isdir(args.data_dir):
......@@ -130,13 +143,13 @@ def main(arguments):
tasks = get_tasks(args.tasks)
for task in tasks:
if task == 'MRPC':
if task == "MRPC":
format_mrpc(args.data_dir, args.path_to_mrpc)
elif task == 'diagnostic':
elif task == "diagnostic":
download_diagnostic(args.data_dir)
else:
download_and_extract(task, args.data_dir)
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
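As a usage note, the script is driven entirely by --data_dir, --tasks, and --path_to_mrpc. The snippet below exercises the entry point with an argv-style list; the module name and the task/directory values are assumptions for illustration, not something this commit prescribes.

# Illustrative only: assumes the script above is importable as
# download_glue_data (adjust to the actual file name in the repository).
import download_glue_data

# Equivalent to: python download_glue_data.py --data_dir glue_data --tasks MRPC,RTE
download_glue_data.main(["--data_dir", "glue_data", "--tasks", "MRPC,RTE"])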
......@@ -43,7 +43,7 @@ def scan_code_for_links(source):
""" Scans the file to find links using a regular expression.
Returns a list of links.
"""
with open(source, 'r') as content:
with open(source, "r") as content:
content = content.read()
raw_links = re.findall(REGEXP_FIND_S3_LINKS, content)
links = [prefix + suffix for _, prefix, suffix in raw_links]
......
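REGEXP_FIND_S3_LINKS itself is not shown in this hunk; the list comprehension above only tells us each match yields three capture groups, of which the first is discarded and the other two are concatenated. Below is a hypothetical pattern with that shape, purely for illustration; the real expression will differ.

import re

# Hypothetical pattern with the three-capture-group shape the comprehension
# above expects (delimiter, prefix, suffix); the real REGEXP_FIND_S3_LINKS is
# not shown in this hunk and will differ in its details.
REGEXP_FIND_S3_LINKS_SKETCH = r"""(["'(])(https://s3\.amazonaws\.com/models\.huggingface\.co/)([^"')\s]*)"""

sample = 'url = "https://s3.amazonaws.com/models.huggingface.co/bert/config.json"'
raw_links = re.findall(REGEXP_FIND_S3_LINKS_SKETCH, sample)
links = [prefix + suffix for _, prefix, suffix in raw_links]
print(links)  # ['https://s3.amazonaws.com/models.huggingface.co/bert/config.json']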