Unverified Commit 54abc67a authored by Thomas Wolf, committed by GitHub

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices
parents 645713e2 c11b3e29
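
The hunks below change formatting and import hygiene rather than behavior: single-quoted strings become double-quoted, long calls and literals are split one element per line with trailing commas, imports are grouped and sorted, and Python 2 compatibility branches gain `# noqa` markers. A minimal, self-contained sketch of the pattern (the variable names here are illustrative, and the formatting tools, e.g. black and isort, are an assumption; the commit itself does not name them):

# Before: single quotes and manual continuation alignment, as on the left-hand side of the hunks
attributes_before = ['bos_token', 'eos_token', 'unk_token', 'sep_token',
                     'pad_token', 'cls_token', 'mask_token']

# After: double quotes, one element per line, trailing comma, as on the right-hand side
attributes_after = [
    "bos_token",
    "eos_token",
    "unk_token",
    "sep_token",
    "pad_token",
    "cls_token",
    "mask_token",
]

assert attributes_before == attributes_after  # only the formatting differs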
@@ -15,30 +15,35 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import shutil
import sys
import tempfile
import unittest
from io import open

if sys.version_info[0] == 2:
    import cPickle as pickle

    class TemporaryDirectory(object):
        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""

        def __enter__(self):
            self.name = tempfile.mkdtemp()
            return self.name

        def __exit__(self, exc_type, exc_value, traceback):
            shutil.rmtree(self.name)


else:
    import pickle

    TemporaryDirectory = tempfile.TemporaryDirectory
    unicode = str


class CommonTestCases:
    class CommonTokenizerTester(unittest.TestCase):

        tokenizer_class = None
@@ -57,17 +62,23 @@ class CommonTestCases:
        def test_tokenizers_common_properties(self):
            tokenizer = self.get_tokenizer()
            attributes_list = [
                "bos_token",
                "eos_token",
                "unk_token",
                "sep_token",
                "pad_token",
                "cls_token",
                "mask_token",
            ]
            for attr in attributes_list:
                self.assertTrue(hasattr(tokenizer, attr))
                self.assertTrue(hasattr(tokenizer, attr + "_id"))

            self.assertTrue(hasattr(tokenizer, "additional_special_tokens"))
            self.assertTrue(hasattr(tokenizer, "additional_special_tokens_ids"))

            attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder", "added_tokens_decoder"]
            for attr in attributes_list:
                self.assertTrue(hasattr(tokenizer, attr))
@@ -79,13 +90,13 @@ class CommonTestCases:
            # Now let's start the test
            tokenizer = self.get_tokenizer(max_len=42)

            before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)

            with TemporaryDirectory() as tmpdirname:
                tokenizer.save_pretrained(tmpdirname)
                tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)

                after_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
                self.assertListEqual(before_tokens, after_tokens)

                self.assertEqual(tokenizer.max_len, 42)
@@ -96,12 +107,12 @@ class CommonTestCases:
            tokenizer = self.get_tokenizer()
            self.assertIsNotNone(tokenizer)

            text = "Munich and Berlin are nice cities"
            subwords = tokenizer.tokenize(text)

            with TemporaryDirectory() as tmpdirname:

                filename = os.path.join(tmpdirname, "tokenizer.bin")
                with open(filename, "wb") as handle:
                    pickle.dump(tokenizer, handle)
@@ -122,7 +133,7 @@ class CommonTestCases:
            toks0 = tokenizer.tokenize(text)  # toks before adding new_toks

            new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]
            added = tokenizer.add_tokens(new_toks)
            self.assertEqual(added, 2)
@@ -178,8 +189,7 @@ class CommonTestCases:
            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)

            new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
            added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
            vocab_size_3 = tokenizer.vocab_size
            all_size_3 = len(tokenizer)
@@ -189,8 +199,9 @@ class CommonTestCases:
            self.assertEqual(added_toks_2, len(new_toks_2))
            self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

            tokens = tokenizer.encode(
                ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
            )
            out_string = tokenizer.decode(tokens)

            self.assertGreaterEqual(len(tokens), 6)
@@ -242,7 +253,7 @@ class CommonTestCases:
        def test_encode_decode_with_spaces(self):
            tokenizer = self.get_tokenizer()

            new_toks = ["[ABC]", "[DEF]", "GHI IHG"]
            tokenizer.add_tokens(new_toks)
            input = "[ABC] [DEF] [ABC] GHI IHG [DEF]"
            encoded = tokenizer.encode(input, add_special_tokens=False)
@@ -264,7 +275,7 @@ class CommonTestCases:
            tokenizer = self.get_tokenizer()

            if tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer":
                seq_0 = "Test this method."
                seq_1 = "With these inputs."
                information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True)
@@ -293,17 +304,19 @@ class CommonTestCases:
            sequence = tokenizer.encode(seq_0, add_special_tokens=False)
            num_added_tokens = tokenizer.num_added_tokens()
            total_length = len(sequence) + num_added_tokens
            information = tokenizer.encode_plus(
                seq_0,
                max_length=total_length - 2,
                add_special_tokens=True,
                stride=stride,
                return_overflowing_tokens=True,
            )

            truncated_sequence = information["input_ids"]
            overflowing_tokens = information["overflowing_tokens"]

            self.assertEqual(len(overflowing_tokens), 2 + stride)
            self.assertEqual(overflowing_tokens, sequence[-(2 + stride) :])
            self.assertEqual(len(truncated_sequence), total_length - 2)
            self.assertEqual(truncated_sequence, tokenizer.build_inputs_with_special_tokens(sequence[:-2]))
@@ -320,24 +333,35 @@ class CommonTestCases:
            sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
            truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
                tokenizer.encode(seq_0, add_special_tokens=False),
                tokenizer.encode(seq_1, add_special_tokens=False)[:-2],
            )

            information = tokenizer.encode_plus(
                seq_0,
                seq_1,
                max_length=len(sequence) - 2,
                add_special_tokens=True,
                stride=stride,
                truncation_strategy="only_second",
                return_overflowing_tokens=True,
            )
            information_first_truncated = tokenizer.encode_plus(
                seq_0,
                seq_1,
                max_length=len(sequence) - 2,
                add_special_tokens=True,
                stride=stride,
                truncation_strategy="only_first",
                return_overflowing_tokens=True,
            )

            truncated_sequence = information["input_ids"]
            overflowing_tokens = information["overflowing_tokens"]
            overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]

            self.assertEqual(len(overflowing_tokens), 2 + stride)
            self.assertEqual(overflowing_tokens, sequence_1_no_special_tokens[-(2 + stride) :])
            self.assertEqual(overflowing_tokens_first_truncated, sequence_0_no_special_tokens[-(2 + stride) :])

            self.assertEqual(len(truncated_sequence), len(sequence) - 2)
            self.assertEqual(truncated_sequence, truncated_second_sequence)
@@ -361,37 +385,47 @@ class CommonTestCases:
            # Testing single inputs
            encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
            encoded_sequence_dict = tokenizer.encode_plus(
                sequence_0, add_special_tokens=True, return_special_tokens_mask=True
            )
            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
            special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
            self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

            filtered_sequence = [
                (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
            ]
            filtered_sequence = [x for x in filtered_sequence if x is not None]
            self.assertEqual(encoded_sequence, filtered_sequence)

            # Testing inputs pairs
            encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(
                sequence_1, add_special_tokens=False
            )
            encoded_sequence_dict = tokenizer.encode_plus(
                sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True
            )
            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
            special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
            self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

            filtered_sequence = [
                (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
            ]
            filtered_sequence = [x for x in filtered_sequence if x is not None]
            self.assertEqual(encoded_sequence, filtered_sequence)

            # Testing with already existing special tokens
            if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
                tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
            encoded_sequence_dict = tokenizer.encode_plus(
                sequence_0, add_special_tokens=True, return_special_tokens_mask=True
            )
            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
            special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
            special_tokens_mask = tokenizer.get_special_tokens_mask(
                encoded_sequence_w_special, already_has_special_tokens=True
            )
            self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
            self.assertEqual(special_tokens_mask_orig, special_tokens_mask)
@@ -406,7 +440,9 @@ class CommonTestCases:
            tokenizer.padding_side = "right"
            encoded_sequence = tokenizer.encode(sequence)
            sequence_length = len(encoded_sequence)
            padded_sequence = tokenizer.encode(
                sequence, max_length=sequence_length + padding_size, pad_to_max_length=True
            )
            padded_sequence_length = len(padded_sequence)
            assert sequence_length + padding_size == padded_sequence_length
            assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
@@ -415,7 +451,9 @@ class CommonTestCases:
            tokenizer.padding_side = "left"
            encoded_sequence = tokenizer.encode(sequence)
            sequence_length = len(encoded_sequence)
            padded_sequence = tokenizer.encode(
                sequence, max_length=sequence_length + padding_size, pad_to_max_length=True
            )
            padded_sequence_length = len(padded_sequence)
            assert sequence_length + padding_size == padded_sequence_length
            assert [padding_idx] * padding_size + encoded_sequence == padded_sequence
@@ -446,38 +484,48 @@ class CommonTestCases:
            token_type_padding_idx = tokenizer.pad_token_type_id

            encoded_sequence = tokenizer.encode_plus(sequence, return_special_tokens_mask=True)
            input_ids = encoded_sequence["input_ids"]
            token_type_ids = encoded_sequence["token_type_ids"]
            attention_mask = encoded_sequence["attention_mask"]
            special_tokens_mask = encoded_sequence["special_tokens_mask"]
            sequence_length = len(input_ids)

            # Test right padding
            tokenizer.padding_side = "right"
            padded_sequence = tokenizer.encode_plus(
                sequence,
                max_length=sequence_length + padding_size,
                pad_to_max_length=True,
                return_special_tokens_mask=True,
            )
            padded_input_ids = padded_sequence["input_ids"]
            padded_token_type_ids = padded_sequence["token_type_ids"]
            padded_attention_mask = padded_sequence["attention_mask"]
            padded_special_tokens_mask = padded_sequence["special_tokens_mask"]
            padded_sequence_length = len(padded_input_ids)

            assert sequence_length + padding_size == padded_sequence_length
            assert input_ids + [padding_idx] * padding_size == padded_input_ids
            assert token_type_ids + [token_type_padding_idx] * padding_size == padded_token_type_ids
            assert attention_mask + [0] * padding_size == padded_attention_mask
            assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask

            # Test left padding
            tokenizer.padding_side = "left"
            padded_sequence = tokenizer.encode_plus(
                sequence,
                max_length=sequence_length + padding_size,
                pad_to_max_length=True,
                return_special_tokens_mask=True,
            )
            padded_input_ids = padded_sequence["input_ids"]
            padded_token_type_ids = padded_sequence["token_type_ids"]
            padded_attention_mask = padded_sequence["attention_mask"]
            padded_special_tokens_mask = padded_sequence["special_tokens_mask"]
            padded_sequence_length = len(padded_input_ids)

            assert sequence_length + padding_size == padded_sequence_length
            assert [padding_idx] * padding_size + input_ids == padded_input_ids
            assert [token_type_padding_idx] * padding_size + token_type_ids == padded_token_type_ids
            assert [0] * padding_size + attention_mask == padded_attention_mask
            assert [1] * padding_size + special_tokens_mask == padded_special_tokens_mask
\ No newline at end of file
@@ -20,14 +20,14 @@ from io import open

from transformers import is_torch_available

from .tokenization_tests_commons import CommonTestCases
from .utils import require_torch

if is_torch_available():
    from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES


@require_torch
class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
@@ -37,45 +37,53 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
        super(TransfoXLTokenizationTest, self).setUp()

        vocab_tokens = [
            "<unk>",
            "[CLS]",
            "[SEP]",
            "want",
            "unwanted",
            "wa",
            "un",
            "running",
            ",",
            "low",
            "l",
        ]
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_tokenizer(self, **kwargs):
        kwargs["lower_case"] = True
        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = "<unk> UNwanted , running"
        output_text = "<unk> unwanted, running"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)

        tokens = tokenizer.tokenize("<unk> UNwanted , running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

    def test_full_tokenizer_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["hello", "!", "how", "are", "you", "?"]
        )

    def test_full_tokenizer_no_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=False)

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
        )


if __name__ == "__main__":
    unittest.main()
@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import unittest

import six

from transformers import PreTrainedTokenizer
@@ -24,8 +23,8 @@ from transformers.tokenization_gpt2 import GPT2Tokenizer
from .utils import slow


class TokenizerUtilsTest(unittest.TestCase):
    def check_tokenizer_from_pretrained(self, tokenizer_class):
        s3_models = list(tokenizer_class.max_model_input_sizes.keys())
        for model_name in s3_models[:1]:
@@ -36,7 +35,7 @@ class TokenizerUtilsTest(unittest.TestCase):
            for special_tok in tokenizer.all_special_tokens:
                if six.PY2:
                    self.assertIsInstance(special_tok, unicode)  # noqa: F821
                else:
                    self.assertIsInstance(special_tok, str)
                special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
@@ -46,5 +45,6 @@ class TokenizerUtilsTest(unittest.TestCase):
    def test_pretrained_tokenizers(self):
        self.check_tokenizer_from_pretrained(GPT2Tokenizer)


if __name__ == "__main__":
    unittest.main()
@@ -14,15 +14,16 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import os
import unittest

from transformers.tokenization_xlm import VOCAB_FILES_NAMES, XLMTokenizer

from .tokenization_tests_commons import CommonTestCases
from .utils import slow


class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):

    tokenizer_class = XLMTokenizer
@@ -31,15 +32,34 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        super(XLMTokenizationTest, self).setUp()

        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "w</w>",
            "r</w>",
            "t</w>",
            "lo",
            "low",
            "er</w>",
            "low</w>",
            "lowest</w>",
            "newer</w>",
            "wider</w>",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w") as fp:
            fp.write(json.dumps(vocab_tokens))
        with open(self.merges_file, "w") as fp:
@@ -49,8 +69,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = "lower newer"
        output_text = "lower newer"
        return input_text, output_text

    def test_full_tokenizer(self):
@@ -64,8 +84,7 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    @slow
    def test_sequence_builders(self):
@@ -80,5 +99,6 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        assert encoded_sentence == [1] + text + [1]
        assert encoded_pair == [1] + text + [1] + text_2 + [1]


if __name__ == "__main__":
    unittest.main()
@@ -17,13 +17,14 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest

from transformers.tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer

from .tokenization_tests_commons import CommonTestCases
from .utils import slow

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")


class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
@@ -40,55 +41,135 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
        return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = "This is a test"
        output_text = "This is a test"
        return input_text, output_text
    def test_full_tokenizer(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "é",
                ".",
            ],
        )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "<unk>",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "<unk>",
                ".",
            ],
        )
    def test_tokenizer_lower(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "",
                "i",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "se",
                ".",
            ],
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["▁he", "ll", "o"])
    def test_tokenizer_no_lower(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "se",
                ".",
            ],
        )

    @slow
    def test_sequence_builders(self):
@@ -104,5 +185,5 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
        assert encoded_pair == text + [4] + text_2 + [4, 3]


if __name__ == "__main__":
    unittest.main()
import os
import tempfile
import unittest
from distutils.util import strtobool

from transformers.file_utils import _tf_available, _torch_available
@@ -27,6 +26,7 @@ def parse_flag_from_env(key, default=False):
        raise ValueError("If set, {} must be yes or no.".format(key))
    return _value


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
...
@@ -13,45 +13,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for ALBERT model."""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import os
import unicodedata
from shutil import copyfile

import six

from .tokenization_utils import PreTrainedTokenizer

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-spiece.model",
        "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-spiece.model",
        "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-spiece.model",
        "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-spiece.model",
        "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model",
        "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-spiece.model",
        "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-spiece.model",
        "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-spiece.model",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "albert-base-v1": 512,
    "albert-large-v1": 512,
    "albert-xlarge-v1": 512,
    "albert-xxlarge-v1": 512,
    "albert-base-v2": 512,
    "albert-large-v2": 512,
    "albert-xlarge-v2": 512,
    "albert-xxlarge-v2": 512,
}

SPIECE_UNDERLINE = "▁"


class AlbertTokenizer(PreTrainedTokenizer):
    """
@@ -59,18 +61,36 @@ class AlbertTokenizer(PreTrainedTokenizer):
        - requires `SentencePiece <https://github.com/google/sentencepiece>`_
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        remove_space=True,
        keep_accents=False,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="<unk>",
        sep_token="[SEP]",
        pad_token="<pad>",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        super(AlbertTokenizer, self).__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs
        )
        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
@@ -78,8 +98,10 @@ class AlbertTokenizer(PreTrainedTokenizer):
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
                "pip install sentencepiece"
            )

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
@@ -103,24 +125,26 @@ class AlbertTokenizer(PreTrainedTokenizer):
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
                "pip install sentencepiece"
            )
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def preprocess_text(self, inputs):
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs
        outputs = outputs.replace("``", '"').replace("''", '"')

        if six.PY2 and isinstance(outputs, str):
            outputs = outputs.decode("utf-8")

        if not self.keep_accents:
            outputs = unicodedata.normalize("NFKD", outputs)
            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
        if self.do_lower_case:
            outputs = outputs.lower()
@@ -132,8 +156,8 @@ class AlbertTokenizer(PreTrainedTokenizer):
        """
        text = self.preprocess_text(text)
        # note(zhiliny): in some systems, sentencepiece only accepts str for py2
        if six.PY2 and isinstance(text, unicode):  # noqa: F821
            text = text.encode("utf-8")

        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
@@ -141,9 +165,8 @@ class AlbertTokenizer(PreTrainedTokenizer):
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        new_pieces = []
        for piece in pieces:
            if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
@@ -159,7 +182,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
            ret_pieces = []
            for piece in new_pieces:
                if isinstance(piece, str):
                    piece = piece.decode("utf-8")
                ret_pieces.append(piece)
            new_pieces = ret_pieces
@@ -173,12 +196,12 @@ class AlbertTokenizer(PreTrainedTokenizer):
        """Converts an index (integer) in a token (string/unicode) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        if six.PY2 and return_unicode and isinstance(token, str):
            token = token.decode("utf-8")
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
@@ -213,8 +236,10 @@ class AlbertTokenizer(PreTrainedTokenizer):
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
@@ -244,7 +269,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
...
@@ -18,23 +18,25 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import logging

from .tokenization_albert import AlbertTokenizer
from .tokenization_bert import BertTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import XLNetTokenizer

logger = logging.getLogger(__name__)


class AutoTokenizer(object):
    r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
        that will be instantiated as one of the tokenizer classes of the library
@@ -62,9 +64,12 @@ class AutoTokenizer(object):
        This class cannot be instantiated using `__init__()` (throw an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
...@@ -125,34 +130,38 @@ class AutoTokenizer(object): ...@@ -125,34 +130,38 @@ class AutoTokenizer(object):
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')
""" """
if 't5' in pretrained_model_name_or_path: if "t5" in pretrained_model_name_or_path:
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path: elif "distilbert" in pretrained_model_name_or_path:
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'albert' in pretrained_model_name_or_path: elif "albert" in pretrained_model_name_or_path:
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'camembert' in pretrained_model_name_or_path: elif "camembert" in pretrained_model_name_or_path:
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlm-roberta' in pretrained_model_name_or_path: elif "xlm-roberta" in pretrained_model_name_or_path:
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'roberta' in pretrained_model_name_or_path: elif "roberta" in pretrained_model_name_or_path:
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'bert-base-japanese' in pretrained_model_name_or_path: elif "bert-base-japanese" in pretrained_model_name_or_path:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'bert' in pretrained_model_name_or_path: elif "bert" in pretrained_model_name_or_path:
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'openai-gpt' in pretrained_model_name_or_path: elif "openai-gpt" in pretrained_model_name_or_path:
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'gpt2' in pretrained_model_name_or_path: elif "gpt2" in pretrained_model_name_or_path:
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'transfo-xl' in pretrained_model_name_or_path: elif "transfo-xl" in pretrained_model_name_or_path:
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path: elif "xlnet" in pretrained_model_name_or_path:
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlm' in pretrained_model_name_or_path: elif "xlm" in pretrained_model_name_or_path:
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path: elif "ctrl" in pretrained_model_name_or_path:
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contain one of " raise ValueError(
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " "Unrecognized model identifier in {}. Should contain one of "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path)) "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
pretrained_model_name_or_path
)
)
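The dispatch above is a plain substring test on the identifier, which is why the more specific names ('xlm-roberta', 'bert-base-japanese') are checked before the generic ones ('roberta', 'bert'). A minimal usage sketch, mirroring the docstring example (downloading 'bert-base-uncased' requires network access):

    from transformers import AutoTokenizer

    # 'bert-base-uncased' contains 'bert', so AutoTokenizer hands back a BertTokenizer.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    ids = tokenizer.encode("Hello world", add_special_tokens=True)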
...@@ -24,71 +24,71 @@ from io import open ...@@ -24,71 +24,71 @@ from io import open
from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt", "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt", "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt", "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-uncased': 512, "bert-base-uncased": 512,
'bert-large-uncased': 512, "bert-large-uncased": 512,
'bert-base-cased': 512, "bert-base-cased": 512,
'bert-large-cased': 512, "bert-large-cased": 512,
'bert-base-multilingual-uncased': 512, "bert-base-multilingual-uncased": 512,
'bert-base-multilingual-cased': 512, "bert-base-multilingual-cased": 512,
'bert-base-chinese': 512, "bert-base-chinese": 512,
'bert-base-german-cased': 512, "bert-base-german-cased": 512,
'bert-large-uncased-whole-word-masking': 512, "bert-large-uncased-whole-word-masking": 512,
'bert-large-cased-whole-word-masking': 512, "bert-large-cased-whole-word-masking": 512,
'bert-large-uncased-whole-word-masking-finetuned-squad': 512, "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
'bert-large-cased-whole-word-masking-finetuned-squad': 512, "bert-large-cased-whole-word-masking-finetuned-squad": 512,
'bert-base-cased-finetuned-mrpc': 512, "bert-base-cased-finetuned-mrpc": 512,
'bert-base-german-dbmdz-cased': 512, "bert-base-german-dbmdz-cased": 512,
'bert-base-german-dbmdz-uncased': 512, "bert-base-german-dbmdz-uncased": 512,
'bert-base-finnish-cased-v1': 512, "bert-base-finnish-cased-v1": 512,
'bert-base-finnish-uncased-v1': 512, "bert-base-finnish-uncased-v1": 512,
} }
PRETRAINED_INIT_CONFIGURATION = { PRETRAINED_INIT_CONFIGURATION = {
'bert-base-uncased': {'do_lower_case': True}, "bert-base-uncased": {"do_lower_case": True},
'bert-large-uncased': {'do_lower_case': True}, "bert-large-uncased": {"do_lower_case": True},
'bert-base-cased': {'do_lower_case': False}, "bert-base-cased": {"do_lower_case": False},
'bert-large-cased': {'do_lower_case': False}, "bert-large-cased": {"do_lower_case": False},
'bert-base-multilingual-uncased': {'do_lower_case': True}, "bert-base-multilingual-uncased": {"do_lower_case": True},
'bert-base-multilingual-cased': {'do_lower_case': False}, "bert-base-multilingual-cased": {"do_lower_case": False},
'bert-base-chinese': {'do_lower_case': False}, "bert-base-chinese": {"do_lower_case": False},
'bert-base-german-cased': {'do_lower_case': False}, "bert-base-german-cased": {"do_lower_case": False},
'bert-large-uncased-whole-word-masking': {'do_lower_case': True}, "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
'bert-large-cased-whole-word-masking': {'do_lower_case': False}, "bert-large-cased-whole-word-masking": {"do_lower_case": False},
'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True}, "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False}, "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
'bert-base-cased-finetuned-mrpc': {'do_lower_case': False}, "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
'bert-base-german-dbmdz-cased': {'do_lower_case': False}, "bert-base-german-dbmdz-cased": {"do_lower_case": False},
'bert-base-german-dbmdz-uncased': {'do_lower_case': True}, "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
'bert-base-finnish-cased-v1': {'do_lower_case': False}, "bert-base-finnish-cased-v1": {"do_lower_case": False},
'bert-base-finnish-uncased-v1': {'do_lower_case': True}, "bert-base-finnish-uncased-v1": {"do_lower_case": True},
} }
...@@ -98,7 +98,7 @@ def load_vocab(vocab_file): ...@@ -98,7 +98,7 @@ def load_vocab(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as reader: with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines() tokens = reader.readlines()
for index, token in enumerate(tokens): for index, token in enumerate(tokens):
token = token.rstrip('\n') token = token.rstrip("\n")
vocab[token] = index vocab[token] = index
return vocab return vocab
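For reference, the vocab.txt layout load_vocab expects is one token per line, with the 0-based line number as the token id. A self-contained sketch of that format (the token list and temporary file are made up for illustration):

    import collections
    import os
    import tempfile

    tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "hello", "world"]
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
        f.write("\n".join(tokens) + "\n")
        vocab_path = f.name

    vocab = collections.OrderedDict()
    with open(vocab_path, "r", encoding="utf-8") as reader:
        for index, token in enumerate(reader.readlines()):
            vocab[token.rstrip("\n")] = index  # line number becomes the token id

    assert vocab["hello"] == 4
    os.remove(vocab_path)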
...@@ -132,9 +132,20 @@ class BertTokenizer(PreTrainedTokenizer): ...@@ -132,9 +132,20 @@ class BertTokenizer(PreTrainedTokenizer):
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, def __init__(
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", self,
mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs): vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
**kwargs
):
"""Constructs a BertTokenizer. """Constructs a BertTokenizer.
Args: Args:
...@@ -152,24 +163,29 @@ class BertTokenizer(PreTrainedTokenizer): ...@@ -152,24 +163,29 @@ class BertTokenizer(PreTrainedTokenizer):
This should likely be deactivated for Japanese: This should likely be deactivated for Japanese:
see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
""" """
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, super(BertTokenizer, self).__init__(
pad_token=pad_token, cls_token=cls_token, unk_token=unk_token,
mask_token=mask_token, **kwargs) sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file): if not os.path.isfile(vocab_file):
raise ValueError( raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
)
self.vocab = load_vocab(vocab_file) self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict( self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
[(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize: if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, self.basic_tokenizer = BasicTokenizer(
never_split=never_split, do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
tokenize_chinese_chars=tokenize_chinese_chars) )
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
@property @property
...@@ -196,7 +212,7 @@ class BertTokenizer(PreTrainedTokenizer): ...@@ -196,7 +212,7 @@ class BertTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) into a single string. """ """ Converts a sequence of tokens (string) into a single string. """
out_string = ' '.join(tokens).replace(' ##', '').strip() out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string return out_string
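The ' ##' replacement is what stitches WordPiece continuation pieces back onto their head token; a one-line illustration of the same expression with toy tokens:

    tokens = ["he", "is", "un", "##want", "##ed"]
    assert " ".join(tokens).replace(" ##", "").strip() == "he is unwanted"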
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
...@@ -231,8 +247,10 @@ class BertTokenizer(PreTrainedTokenizer): ...@@ -231,8 +247,10 @@ class BertTokenizer(PreTrainedTokenizer):
if already_has_special_tokens: if already_has_special_tokens:
if token_ids_1 is not None: if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of " raise ValueError(
"ids is already formatted with special tokens for the model.") "You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None: if token_ids_1 is not None:
...@@ -258,16 +276,18 @@ class BertTokenizer(PreTrainedTokenizer): ...@@ -258,16 +276,18 @@ class BertTokenizer(PreTrainedTokenizer):
"""Save the tokenizer vocabulary to a directory or file.""" """Save the tokenizer vocabulary to a directory or file."""
index = 0 index = 0
if os.path.isdir(vocab_path): if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
else: else:
vocab_file = vocab_path vocab_file = vocab_path
with open(vocab_file, "w", encoding="utf-8") as writer: with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." logger.warning(
" Please check that the vocabulary is not corrupted!".format(vocab_file)) "Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file)
)
index = token_index index = token_index
writer.write(token + u'\n') writer.write(token + "\n")
index += 1 index += 1
return (vocab_file,) return (vocab_file,)
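Because tokens are written in ascending id order, one per line, the saved file has exactly the layout load_vocab reads back. A standalone round-trip sketch with a toy vocabulary:

    vocab = {"[PAD]": 0, "[UNK]": 1, "hello": 2, "world": 3}
    lines = [tok for tok, idx in sorted(vocab.items(), key=lambda kv: kv[1])]
    restored = {tok: i for i, tok in enumerate(lines)}
    assert restored == vocab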
...@@ -382,14 +402,16 @@ class BasicTokenizer(object): ...@@ -382,14 +402,16 @@ class BasicTokenizer(object):
# as is Japanese Hiragana and Katakana. Those alphabets are used to write # as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled # space-separated words, so they are not treated specially and handled
# like all of the other languages. # like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or # if (
(cp >= 0x3400 and cp <= 0x4DBF) or # (cp >= 0x4E00 and cp <= 0x9FFF)
(cp >= 0x20000 and cp <= 0x2A6DF) or # or (cp >= 0x3400 and cp <= 0x4DBF) #
(cp >= 0x2A700 and cp <= 0x2B73F) or # or (cp >= 0x20000 and cp <= 0x2A6DF) #
(cp >= 0x2B740 and cp <= 0x2B81F) or # or (cp >= 0x2A700 and cp <= 0x2B73F) #
(cp >= 0x2B820 and cp <= 0x2CEAF) or or (cp >= 0x2B740 and cp <= 0x2B81F) #
(cp >= 0xF900 and cp <= 0xFAFF) or # or (cp >= 0x2B820 and cp <= 0x2CEAF) #
(cp >= 0x2F800 and cp <= 0x2FA1F)): # or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True return True
return False return False
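The ranges above are the CJK Unified Ideographs blocks plus their extension and compatibility planes. A quick standalone check of the main blocks, just to show which scripts are and are not caught (the extension planes are omitted here for brevity):

    def is_cjk_ideograph(cp):
        return (0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or (0xF900 <= cp <= 0xFAFF)

    assert is_cjk_ideograph(ord("中"))
    assert not is_cjk_ideograph(ord("あ"))  # Hiragana is deliberately not treated as CJK here
    assert not is_cjk_ideograph(ord("a"))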
...@@ -399,7 +421,7 @@ class BasicTokenizer(object): ...@@ -399,7 +421,7 @@ class BasicTokenizer(object):
output = [] output = []
for char in text: for char in text:
cp = ord(char) cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char): if cp == 0 or cp == 0xFFFD or _is_control(char):
continue continue
if _is_whitespace(char): if _is_whitespace(char):
output.append(" ") output.append(" ")
...@@ -499,8 +521,7 @@ def _is_punctuation(char): ...@@ -499,8 +521,7 @@ def _is_punctuation(char):
# Characters such as "^", "$", and "`" are not in the Unicode # Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for # Punctuation class but we treat them as punctuation anyways, for
# consistency. # consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True return True
cat = unicodedata.category(char) cat = unicodedata.category(char)
if cat.startswith("P"): if cat.startswith("P"):
......
...@@ -19,55 +19,54 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -19,55 +19,54 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import collections import collections
import logging import logging
import os import os
import six
import unicodedata import unicodedata
from io import open
from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer, load_vocab import six
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer, load_vocab
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-vocab.txt",
'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-vocab.txt", "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-vocab.txt",
'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-vocab.txt", "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-vocab.txt",
'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-vocab.txt", "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-vocab.txt",
'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-vocab.txt"
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-japanese': 512, "bert-base-japanese": 512,
'bert-base-japanese-whole-word-masking': 512, "bert-base-japanese-whole-word-masking": 512,
'bert-base-japanese-char': 512, "bert-base-japanese-char": 512,
'bert-base-japanese-char-whole-word-masking': 512 "bert-base-japanese-char-whole-word-masking": 512,
} }
PRETRAINED_INIT_CONFIGURATION = { PRETRAINED_INIT_CONFIGURATION = {
'bert-base-japanese': { "bert-base-japanese": {
'do_lower_case': False, "do_lower_case": False,
'word_tokenizer_type': 'mecab', "word_tokenizer_type": "mecab",
'subword_tokenizer_type': 'wordpiece' "subword_tokenizer_type": "wordpiece",
}, },
'bert-base-japanese-whole-word-masking':{ "bert-base-japanese-whole-word-masking": {
'do_lower_case': False, "do_lower_case": False,
'word_tokenizer_type': 'mecab', "word_tokenizer_type": "mecab",
'subword_tokenizer_type': 'wordpiece' "subword_tokenizer_type": "wordpiece",
}, },
'bert-base-japanese-char': { "bert-base-japanese-char": {
'do_lower_case': False, "do_lower_case": False,
'word_tokenizer_type': 'mecab', "word_tokenizer_type": "mecab",
'subword_tokenizer_type': 'character' "subword_tokenizer_type": "character",
},
"bert-base-japanese-char-whole-word-masking": {
"do_lower_case": False,
"word_tokenizer_type": "mecab",
"subword_tokenizer_type": "character",
}, },
'bert-base-japanese-char-whole-word-masking': {
'do_lower_case': False,
'word_tokenizer_type': 'mecab',
'subword_tokenizer_type': 'character'
}
} }
...@@ -79,11 +78,22 @@ class BertJapaneseTokenizer(BertTokenizer): ...@@ -79,11 +78,22 @@ class BertJapaneseTokenizer(BertTokenizer):
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=False, def __init__(
do_word_tokenize=True, do_subword_tokenize=True, self,
word_tokenizer_type='basic', subword_tokenizer_type='wordpiece', vocab_file,
never_split=None, unk_token='[UNK]', sep_token='[SEP]', do_lower_case=False,
pad_token='[PAD]', cls_token='[CLS]', mask_token='[MASK]', **kwargs): do_word_tokenize=True,
do_subword_tokenize=True,
word_tokenizer_type="basic",
subword_tokenizer_type="wordpiece",
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
**kwargs
):
"""Constructs a BertJapaneseTokenizer. """Constructs a BertJapaneseTokenizer.
Args: Args:
...@@ -100,56 +110,53 @@ class BertJapaneseTokenizer(BertTokenizer): ...@@ -100,56 +110,53 @@ class BertJapaneseTokenizer(BertTokenizer):
**subword_tokenizer_type**: (`optional`) string (default "wordpiece") **subword_tokenizer_type**: (`optional`) string (default "wordpiece")
Type of subword tokenizer. Type of subword tokenizer.
""" """
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, super(BertTokenizer, self).__init__(
pad_token=pad_token, cls_token=cls_token, unk_token=unk_token,
mask_token=mask_token, **kwargs) sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file): if not os.path.isfile(vocab_file):
raise ValueError( raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
)
self.vocab = load_vocab(vocab_file) self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict( self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
[(ids, tok) for tok, ids in self.vocab.items()])
self.do_word_tokenize = do_word_tokenize self.do_word_tokenize = do_word_tokenize
if do_word_tokenize: if do_word_tokenize:
if word_tokenizer_type == 'basic': if word_tokenizer_type == "basic":
self.word_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, self.word_tokenizer = BasicTokenizer(
never_split=never_split, do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=False
tokenize_chinese_chars=False) )
elif word_tokenizer_type == 'mecab': elif word_tokenizer_type == "mecab":
self.word_tokenizer = MecabTokenizer(do_lower_case=do_lower_case, self.word_tokenizer = MecabTokenizer(do_lower_case=do_lower_case, never_split=never_split)
never_split=never_split)
else: else:
raise ValueError( raise ValueError("Invalid word_tokenizer_type '{}' is specified.".format(word_tokenizer_type))
"Invalid word_tokenizer_type '{}' is specified.".format(word_tokenizer_type))
self.do_subword_tokenize = do_subword_tokenize self.do_subword_tokenize = do_subword_tokenize
if do_subword_tokenize: if do_subword_tokenize:
if subword_tokenizer_type == 'wordpiece': if subword_tokenizer_type == "wordpiece":
self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
unk_token=self.unk_token) elif subword_tokenizer_type == "character":
elif subword_tokenizer_type == 'character': self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token)
self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab,
unk_token=self.unk_token)
else: else:
raise ValueError( raise ValueError("Invalid subword_tokenizer_type '{}' is specified.".format(subword_tokenizer_type))
"Invalid subword_tokenizer_type '{}' is specified.".format(subword_tokenizer_type))
def _tokenize(self, text): def _tokenize(self, text):
if self.do_word_tokenize: if self.do_word_tokenize:
tokens = self.word_tokenizer.tokenize(text, tokens = self.word_tokenizer.tokenize(text, never_split=self.all_special_tokens)
never_split=self.all_special_tokens)
else: else:
tokens = [text] tokens = [text]
if self.do_subword_tokenize: if self.do_subword_tokenize:
split_tokens = [sub_token for token in tokens split_tokens = [sub_token for token in tokens for sub_token in self.subword_tokenizer.tokenize(token)]
for sub_token in self.subword_tokenizer.tokenize(token)]
else: else:
split_tokens = tokens split_tokens = tokens
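The flow above is a two-stage pipeline: a word-level tokenizer (MeCab or the basic one) followed by a subword tokenizer (WordPiece or character). A shape-only sketch with stand-in functions, since MeCab may not be installed:

    def word_tokenize(text):
        return text.split()      # stand-in for MecabTokenizer / BasicTokenizer

    def subword_tokenize(token):
        return list(token)       # stand-in for WordpieceTokenizer / CharacterTokenizer

    tokens = word_tokenize("foo bar")
    split_tokens = [sub for tok in tokens for sub in subword_tokenize(tok)]
    assert split_tokens == ["f", "o", "o", "b", "a", "r"]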
...@@ -177,27 +184,28 @@ class MecabTokenizer(object): ...@@ -177,27 +184,28 @@ class MecabTokenizer(object):
self.normalize_text = normalize_text self.normalize_text = normalize_text
import MeCab import MeCab
self.mecab = MeCab.Tagger() self.mecab = MeCab.Tagger()
def tokenize(self, text, never_split=None, **kwargs): def tokenize(self, text, never_split=None, **kwargs):
"""Tokenizes a piece of text.""" """Tokenizes a piece of text."""
if self.normalize_text: if self.normalize_text:
text = unicodedata.normalize('NFKC', text) text = unicodedata.normalize("NFKC", text)
never_split = self.never_split + (never_split if never_split is not None else []) never_split = self.never_split + (never_split if never_split is not None else [])
tokens = [] tokens = []
if six.PY2: if six.PY2:
mecab_output = self.mecab.parse(text.encode('utf-8')).decode('utf-8') mecab_output = self.mecab.parse(text.encode("utf-8")).decode("utf-8")
else: else:
mecab_output = self.mecab.parse(text) mecab_output = self.mecab.parse(text)
cursor = 0 cursor = 0
for line in mecab_output.split('\n'): for line in mecab_output.split("\n"):
if line == 'EOS': if line == "EOS":
break break
token, _ = line.split('\t') token, _ = line.split("\t")
token_start = text.index(token, cursor) token_start = text.index(token, cursor)
token_end = token_start + len(token) token_end = token_start + len(token)
if self.do_lower_case and token not in never_split: if self.do_lower_case and token not in never_split:
...@@ -240,7 +248,7 @@ class CharacterTokenizer(object): ...@@ -240,7 +248,7 @@ class CharacterTokenizer(object):
A list of characters. A list of characters.
""" """
if self.normalize_text: if self.normalize_text:
text = unicodedata.normalize('NFKC', text) text = unicodedata.normalize("NFKC", text)
output_tokens = [] output_tokens = []
for i, char in enumerate(text): for i, char in enumerate(text):
......
...@@ -13,32 +13,34 @@ ...@@ -13,32 +13,34 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
""" Tokenization classes for Camembert model.""" """ Tokenization classes for Camembert model."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import logging import logging
import os import os
from shutil import copyfile from shutil import copyfile
import sentencepiece as spm import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE from .tokenization_xlnet import SPIECE_UNDERLINE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'} VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model",
'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model",
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'camembert-base': None, "camembert-base": None,
} }
class CamembertTokenizer(PreTrainedTokenizer): class CamembertTokenizer(PreTrainedTokenizer):
""" """
Adapted from RobertaTokenizer and XLNetTokenizer Adapted from RobertaTokenizer and XLNetTokenizer
...@@ -46,17 +48,36 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -46,17 +48,36 @@ class CamembertTokenizer(PreTrainedTokenizer):
- requires `SentencePiece <https://github.com/google/sentencepiece>`_ - requires `SentencePiece <https://github.com/google/sentencepiece>`_
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>", def __init__(
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', self,
additional_special_tokens=['<s>NOTUSED', '</s>NOTUSED'], **kwargs): vocab_file,
super(CamembertTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, bos_token="<s>",
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, eos_token="</s>",
mask_token=mask_token, additional_special_tokens=additional_special_tokens, sep_token="</s>",
**kwargs) cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
additional_special_tokens=["<s>NOTUSED", "</s>NOTUSED"],
**kwargs
):
super(CamembertTokenizer, self).__init__(
max_len=512,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
self.sp_model = spm.SentencePieceProcessor() self.sp_model = spm.SentencePieceProcessor()
...@@ -64,9 +85,9 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -64,9 +85,9 @@ class CamembertTokenizer(PreTrainedTokenizer):
self.vocab_file = vocab_file self.vocab_file = vocab_file
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual # HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
# sentencepiece vocabulary (this is the case for <s> and </s>) # sentencepiece vocabulary (this is the case for <s> and </s>)
self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3} self.fairseq_tokens_to_ids = {"<s>NOTUSED": 0, "<pad>": 1, "</s>NOTUSED": 2, "<unk>": 3}
self.fairseq_offset = len(self.fairseq_tokens_to_ids) self.fairseq_offset = len(self.fairseq_tokens_to_ids)
self.fairseq_tokens_to_ids['<mask>'] = len(self.sp_model) + len(self.fairseq_tokens_to_ids) self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
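If the _convert_token_to_id / _convert_id_to_token methods follow the usual fairseq-compatible pattern (they are not shown in this hunk), a SentencePiece id is simply shifted past the four reserved slots. A sketch of that arithmetic, under that assumption:

    fairseq_tokens_to_ids = {"<s>NOTUSED": 0, "<pad>": 1, "</s>NOTUSED": 2, "<unk>": 3}
    fairseq_offset = len(fairseq_tokens_to_ids)  # 4

    def sp_id_to_model_id(sp_id):
        # assumed mapping: sentencepiece ids start after the reserved fairseq slots
        return sp_id + fairseq_offset

    assert sp_id_to_model_id(0) == 4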
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
...@@ -100,8 +121,10 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -100,8 +121,10 @@ class CamembertTokenizer(PreTrainedTokenizer):
""" """
if already_has_special_tokens: if already_has_special_tokens:
if token_ids_1 is not None: if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of " raise ValueError(
"ids is already formated with special tokens for the model.") "You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is None: if token_ids_1 is None:
...@@ -148,7 +171,7 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -148,7 +171,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) into a single string.""" """Converts a sequence of tokens (strings for sub-words) into a single string."""
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string return out_string
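SPIECE_UNDERLINE is the SentencePiece word-boundary marker imported from tokenization_xlnet; its value is not shown in this diff, so the "\u2581" below is an assumption. A detokenization example under that assumption:

    SPIECE_UNDERLINE = "\u2581"  # assumed value of the imported constant ('▁')
    tokens = ["\u2581J", "'", "aime", "\u2581le", "\u2581camembert"]
    assert "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() == "J'aime le camembert"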
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
...@@ -158,7 +181,7 @@ class CamembertTokenizer(PreTrainedTokenizer): ...@@ -158,7 +181,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file) copyfile(self.vocab_file, out_vocab_file)
......
...@@ -13,37 +13,32 @@ ...@@ -13,37 +13,32 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for Salesforce CTRL.""" """Tokenization classes for Salesforce CTRL."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import json import json
import logging import logging
import os import os
import regex as re
from io import open from io import open
import regex as re
from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', "vocab_file": "vocab.json",
'merges_file': 'merges.txt', "merges_file": "merges.txt",
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json"},
{ "merges_file": {"ctrl": "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt"},
'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json",
},
'merges_file':
{
'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt",
},
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'ctrl': 256, "ctrl": 256,
} }
CONTROL_CODES = { CONTROL_CODES = {
...@@ -104,6 +99,7 @@ CONTROL_CODES = { ...@@ -104,6 +99,7 @@ CONTROL_CODES = {
"multilingual": 128406, "multilingual": 128406,
} }
def get_pairs(word): def get_pairs(word):
"""Return set of symbol pairs in a word. """Return set of symbol pairs in a word.
...@@ -118,11 +114,13 @@ def get_pairs(word): ...@@ -118,11 +114,13 @@ def get_pairs(word):
pairs = set(pairs) pairs = set(pairs)
return pairs return pairs
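get_pairs is shared boilerplate for the BPE loop below: it lists the adjacent symbol bigrams of a word. Its body is partly elided in this hunk, so here is a self-contained sketch of the same idea:

    def get_pairs_sketch(word):
        pairs = set()
        prev_char = word[0]
        for char in word[1:]:
            pairs.add((prev_char, char))
            prev_char = char
        return pairs

    assert get_pairs_sketch(("l", "o", "w", "er</w>")) == {("l", "o"), ("o", "w"), ("w", "er</w>")}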
class CTRLTokenizer(PreTrainedTokenizer): class CTRLTokenizer(PreTrainedTokenizer):
""" """
CTRL BPE tokenizer. Peculiarities: CTRL BPE tokenizer. Peculiarities:
- Byte-Pair-Encoding - Byte-Pair-Encoding
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
...@@ -130,14 +128,18 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -130,14 +128,18 @@ class CTRLTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs): def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs) super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len_single_sentence = (
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
with open(vocab_file, encoding="utf-8") as vocab_handle: with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle) self.encoder = json.load(vocab_handle)
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v: k for k, v in self.encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle: with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split('\n')[1:-1] merges = merges_handle.read().split("\n")[1:-1]
merges = [tuple(merge.split()) for merge in merges] merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {} self.cache = {}
...@@ -150,14 +152,14 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -150,14 +152,14 @@ class CTRLTokenizer(PreTrainedTokenizer):
if token in self.cache: if token in self.cache:
return self.cache[token] return self.cache[token]
word = tuple(token) word = tuple(token)
word = tuple(list(word[:-1]) + [word[-1]+'</w>']) word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
pairs = get_pairs(word) pairs = get_pairs(word)
if not pairs: if not pairs:
return token return token
while True: while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
...@@ -166,14 +168,15 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -166,14 +168,15 @@ class CTRLTokenizer(PreTrainedTokenizer):
while i < len(word): while i < len(word):
try: try:
j = word.index(first, i) j = word.index(first, i)
new_word.extend(word[i:j]) except ValueError:
i = j
except:
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first+second) new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
...@@ -184,7 +187,7 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -184,7 +187,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
break break
else: else:
pairs = get_pairs(word) pairs = get_pairs(word)
word = '@@ '.join(word) word = "@@ ".join(word)
word = word[:-4] word = word[:-4]
self.cache[token] = word self.cache[token] = word
return word return word
...@@ -194,10 +197,10 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -194,10 +197,10 @@ class CTRLTokenizer(PreTrainedTokenizer):
""" """
split_tokens = [] split_tokens = []
words = re.findall(r'\S+\n?', text) words = re.findall(r"\S+\n?", text)
for token in words: for token in words:
split_tokens.extend([t for t in self.bpe(token).split(' ')]) split_tokens.extend([t for t in self.bpe(token).split(" ")])
return split_tokens return split_tokens
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
...@@ -210,7 +213,7 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -210,7 +213,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) into a single string. """ """ Converts a sequence of tokens (string) into a single string. """
out_string = ' '.join(tokens).replace('@@ ', '').strip() out_string = " ".join(tokens).replace("@@ ", "").strip()
return out_string return out_string
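Same pattern as the BERT and GPT-2 detokenizers, but with CTRL's '@@ ' continuation marker instead of '##':

    tokens = ["hel@@", "lo", "wor@@", "ld"]
    assert " ".join(tokens).replace("@@ ", "").strip() == "hello world"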
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
...@@ -218,21 +221,23 @@ class CTRLTokenizer(PreTrainedTokenizer): ...@@ -218,21 +221,23 @@ class CTRLTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f: with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False)) f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0 index = 0
with open(merge_file, "w", encoding="utf-8") as writer: with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n') writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." logger.warning(
" Please check that the tokenizer is not corrupted!".format(merge_file)) "Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index index = token_index
writer.write(' '.join(bpe_tokens) + u'\n') writer.write(" ".join(bpe_tokens) + "\n")
index += 1 index += 1
return vocab_file, merge_file return vocab_file, merge_file
......
...@@ -16,33 +16,29 @@ ...@@ -16,33 +16,29 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import logging import logging
import os
import unicodedata
from io import open
from .tokenization_bert import BertTokenizer from .tokenization_bert import BertTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", "distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", "distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt",
'distilbert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt", "distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'distilbert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'distilbert-base-uncased': 512, "distilbert-base-uncased": 512,
'distilbert-base-uncased-distilled-squad': 512, "distilbert-base-uncased-distilled-squad": 512,
'distilbert-base-german-cased': 512, "distilbert-base-german-cased": 512,
'distilbert-base-multilingual-cased': 512, "distilbert-base-multilingual-cased": 512,
} }
......
...@@ -13,16 +13,19 @@ ...@@ -13,16 +13,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for OpenAI GPT.""" """Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import sys
import json import json
import logging import logging
import os import os
import regex as re import sys
from io import open from io import open
import regex as re
from .tokenization_utils import PreTrainedTokenizer
try: try:
from functools import lru_cache from functools import lru_cache
except ImportError: except ImportError:
...@@ -31,42 +34,40 @@ except ImportError: ...@@ -31,42 +34,40 @@ except ImportError:
def lru_cache(): def lru_cache():
return lambda func: func return lambda func: func
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', "vocab_file": "vocab.json",
'merges_file': 'merges.txt', "merges_file": "merges.txt",
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json",
'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-vocab.json", "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
}, },
'merges_file': "merges_file": {
{ "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt",
'gpt2-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-merges.txt", "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
}, },
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'gpt2': 1024, "gpt2": 1024,
'gpt2-medium': 1024, "gpt2-medium": 1024,
'gpt2-large': 1024, "gpt2-large": 1024,
'gpt2-xl': 1024, "gpt2-xl": 1024,
'distilgpt2': 1024, "distilgpt2": 1024,
} }
@lru_cache() @lru_cache()
def bytes_to_unicode(): def bytes_to_unicode():
""" """
...@@ -79,18 +80,21 @@ def bytes_to_unicode(): ...@@ -79,18 +80,21 @@ def bytes_to_unicode():
This is a significant percentage of your normal, say, 32K bpe vocab. This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
""" """
_chr = unichr if sys.version_info[0] == 2 else chr _chr = unichr if sys.version_info[0] == 2 else chr # noqa: F821
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:] cs = bs[:]
n = 0 n = 0
for b in range(2**8): for b in range(2 ** 8):
if b not in bs: if b not in bs:
bs.append(b) bs.append(b)
cs.append(2**8+n) cs.append(2 ** 8 + n)
n += 1 n += 1
cs = [_chr(n) for n in cs] cs = [_chr(n) for n in cs]
return dict(zip(bs, cs)) return dict(zip(bs, cs))
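The table maps every byte 0-255 to a distinct printable character, so the BPE step never sees raw spaces or control bytes. A Python 3-only re-run of the construction above as a standalone sketch, checking the properties that matter:

    def bytes_to_unicode_sketch():
        bs = (
            list(range(ord("!"), ord("~") + 1))
            + list(range(ord("¡"), ord("¬") + 1))
            + list(range(ord("®"), ord("ÿ") + 1))
        )
        cs = bs[:]
        n = 0
        for b in range(2 ** 8):
            if b not in bs:
                bs.append(b)
                cs.append(2 ** 8 + n)
                n += 1
        return dict(zip(bs, [chr(c) for c in cs]))

    table = bytes_to_unicode_sketch()
    assert len(table) == 256 and len(set(table.values())) == 256  # a bijection over bytes
    assert table[ord(" ")] == "Ġ"                                 # the space byte is remapped to 'Ġ'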
def get_pairs(word): def get_pairs(word):
"""Return set of symbol pairs in a word. """Return set of symbol pairs in a word.
...@@ -103,6 +107,7 @@ def get_pairs(word): ...@@ -103,6 +107,7 @@ def get_pairs(word):
prev_char = char prev_char = char
return pairs return pairs
class GPT2Tokenizer(PreTrainedTokenizer): class GPT2Tokenizer(PreTrainedTokenizer):
""" """
GPT-2 BPE tokenizer. Peculiarities: GPT-2 BPE tokenizer. Peculiarities:
...@@ -112,15 +117,28 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -112,15 +117,28 @@ class GPT2Tokenizer(PreTrainedTokenizer):
Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not conserve
the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"` the spaces at the beginning of a string: `tokenizer.decode(tokenizer.encode(" Hello")) = "Hello"`
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>", def __init__(
bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs): self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
**kwargs
):
super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len_single_sentence = (
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
with open(vocab_file, encoding="utf-8") as vocab_handle: with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle) self.encoder = json.load(vocab_handle)
...@@ -128,8 +146,8 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -128,8 +146,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
self.errors = errors # how to handle errors in decoding self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode() self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle: with open(merges_file, encoding="utf-8") as merges_handle:
bpe_merges = merges_handle.read().split('\n')[1:-1] bpe_merges = merges_handle.read().split("\n")[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_merges] bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {} self.cache = {}
...@@ -151,7 +169,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -151,7 +169,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return token return token
while True: while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
...@@ -160,14 +178,15 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -160,14 +178,15 @@ class GPT2Tokenizer(PreTrainedTokenizer):
while i < len(word): while i < len(word):
try: try:
j = word.index(first, i) j = word.index(first, i)
new_word.extend(word[i:j]) except ValueError:
i = j
except:
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first+second) new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
...@@ -178,7 +197,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -178,7 +197,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
break break
else: else:
pairs = get_pairs(word) pairs = get_pairs(word)
word = ' '.join(word) word = " ".join(word)
self.cache[token] = word self.cache[token] = word
return word return word
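For reference, a minimal self-contained sketch of the greedy merge loop that `bpe()` implements above, using a toy merge table (the real ranks are loaded from `merges.txt`; `toy_bpe` and its sample ranks are illustrative only):

def toy_bpe(token, bpe_ranks):
    """Greedily apply the lowest-ranked known merge until no mergeable pair remains."""
    word = tuple(token)
    while len(word) > 1:
        pairs = {(word[i], word[i + 1]) for i in range(len(word) - 1)}
        # Pick the pair with the lowest merge rank; stop if none is known.
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
    return " ".join(word)

print(toy_bpe("lower", {("l", "o"): 0, ("lo", "w"): 1}))  # -> "low e r"

Lower-ranked pairs are merged first, so the line order in merges.txt directly encodes merge priority.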
...@@ -189,15 +208,19 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -189,15 +208,19 @@ class GPT2Tokenizer(PreTrainedTokenizer):
Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers. Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
""" """
if add_prefix_space: if add_prefix_space:
text = ' ' + text text = " " + text
bpe_tokens = [] bpe_tokens = []
for token in re.findall(self.pat, text): for token in re.findall(self.pat, text):
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) token = "".join(

self.byte_encoder[ord(b)] for b in token
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
else: else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) token = "".join(
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) self.byte_encoder[b] for b in token.encode("utf-8")
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens return bpe_tokens
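A hedged usage sketch of the byte-level behaviour above; it assumes the library is importable as `transformers` and that the public `gpt2` checkpoint can be loaded:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Every input byte is first mapped to a printable unicode character by
# byte_encoder, so arbitrary text round-trips without <unk> tokens.
ids = tokenizer.encode("Hello, café", add_special_tokens=False)
print(tokenizer.decode(ids))  # Hello, café

# As the class docstring warns, a leading space is not conserved unless the
# text is tokenized with add_prefix_space:
print(tokenizer.decode(tokenizer.encode(" Hello", add_special_tokens=False)))  # "Hello"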
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
...@@ -210,8 +233,8 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -210,8 +233,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """ """ Converts a sequence of tokens (string) in a single string. """
text = ''.join(tokens) text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text return text
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
...@@ -219,21 +242,23 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -219,21 +242,23 @@ class GPT2Tokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f: with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False)) f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0 index = 0
with open(merge_file, "w", encoding="utf-8") as writer: with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n') writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." logger.warning(
" Please check that the tokenizer is not corrupted!".format(merge_file)) "Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index index = token_index
writer.write(' '.join(bpe_tokens) + u'\n') writer.write(" ".join(bpe_tokens) + "\n")
index += 1 index += 1
return vocab_file, merge_file return vocab_file, merge_file
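A short sketch of the two files `save_vocabulary()` writes, with names taken from `VOCAB_FILES_NAMES` above; `tokenizer` is assumed to be an already-loaded `GPT2Tokenizer`:

import json
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    vocab_file, merge_file = tokenizer.save_vocabulary(tmpdir)

    with open(vocab_file, encoding="utf-8") as f:
        encoder = json.load(f)                         # vocab.json: token -> id mapping

    with open(merge_file, encoding="utf-8") as f:
        assert f.readline().startswith("#version:")    # header written above
        first_merge = tuple(f.readline().split())      # highest-priority BPE merge pair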
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for OpenAI GPT.""" """Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import json import json
import logging import logging
...@@ -22,31 +21,27 @@ import os ...@@ -22,31 +21,27 @@ import os
import re import re
from io import open from io import open
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', "vocab_file": "vocab.json",
'merges_file': 'merges.txt', "merges_file": "merges.txt",
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json"},
{ "merges_file": {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt"},
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
},
'merges_file':
{
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
},
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'openai-gpt': 512, "openai-gpt": 512,
} }
def get_pairs(word): def get_pairs(word):
""" """
Return set of symbol pairs in a word. Return set of symbol pairs in a word.
...@@ -59,27 +54,30 @@ def get_pairs(word): ...@@ -59,27 +54,30 @@ def get_pairs(word):
prev_char = char prev_char = char
return pairs return pairs
def text_standardize(text): def text_standardize(text):
""" """
fixes some issues the spacy tokenizer had on books corpus fixes some issues the spacy tokenizer had on books corpus
also does some whitespace standardization also does some whitespace standardization
""" """
text = text.replace('—', '-') text = text.replace("—", "-")
text = text.replace('–', '-') text = text.replace("–", "-")
text = text.replace('―', '-') text = text.replace("―", "-")
text = text.replace('…', '...') text = text.replace("…", "...")
text = text.replace('´', "'") text = text.replace("´", "'")
text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) text = re.sub(r"""(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)""", r" \1 ", text)
text = re.sub(r'\s*\n\s*', ' \n ', text) text = re.sub(r"\s*\n\s*", " \n ", text)
text = re.sub(r'[^\S\n]+', ' ', text) text = re.sub(r"[^\S\n]+", " ", text)
return text.strip() return text.strip()
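A quick example of what `text_standardize()` does, assuming it is imported from this module; the exact spacing follows the regexes above:

sample = "He said—wait…   really?"
print(text_standardize(sample))
# Dashes are normalised, the ellipsis becomes "...", punctuation such as "-"
# and "?" is padded with spaces, and runs of whitespace collapse, giving
# (roughly): "He said - wait... really ?"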
class OpenAIGPTTokenizer(PreTrainedTokenizer): class OpenAIGPTTokenizer(PreTrainedTokenizer):
""" """
BPE tokenizer. Peculiarities: BPE tokenizer. Peculiarities:
- lower case all inputs - lower case all inputs
- uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
...@@ -87,12 +85,17 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -87,12 +85,17 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs): def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs) super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len_single_sentence = (
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens self.max_len
) # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = (
self.max_len
) # no default special tokens - you can update this value if you add special tokens
try: try:
import ftfy import ftfy
from spacy.lang.en import English from spacy.lang.en import English
_nlp = English() _nlp = English()
self.nlp = _nlp.Defaults.create_tokenizer(_nlp) self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
self.fix_text = ftfy.fix_text self.fix_text = ftfy.fix_text
...@@ -103,9 +106,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -103,9 +106,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
with open(vocab_file, encoding="utf-8") as vocab_handle: with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle) self.encoder = json.load(vocab_handle)
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v: k for k, v in self.encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle: with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split('\n')[1:-1] merges = merges_handle.read().split("\n")[1:-1]
merges = [tuple(merge.split()) for merge in merges] merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {} self.cache = {}
...@@ -115,16 +118,16 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -115,16 +118,16 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
return len(self.encoder) return len(self.encoder)
def bpe(self, token): def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',) word = tuple(token[:-1]) + (token[-1] + "</w>",)
if token in self.cache: if token in self.cache:
return self.cache[token] return self.cache[token]
pairs = get_pairs(word) pairs = get_pairs(word)
if not pairs: if not pairs:
return token+'</w>' return token + "</w>"
while True: while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
...@@ -133,14 +136,15 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -133,14 +136,15 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
while i < len(word): while i < len(word):
try: try:
j = word.index(first, i) j = word.index(first, i)
new_word.extend(word[i:j]) except ValueError:
i = j
except:
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first+second) new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
...@@ -151,9 +155,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -151,9 +155,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
break break
else: else:
pairs = get_pairs(word) pairs = get_pairs(word)
word = ' '.join(word) word = " ".join(word)
if word == '\n </w>': if word == "\n </w>":
word = '\n</w>' word = "\n</w>"
self.cache[token] = word self.cache[token] = word
return word return word
...@@ -164,12 +168,12 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -164,12 +168,12 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
# Using BERT's BasicTokenizer # Using BERT's BasicTokenizer
text = self.nlp.tokenize(text) text = self.nlp.tokenize(text)
for token in text: for token in text:
split_tokens.extend([t for t in self.bpe(token).split(' ')]) split_tokens.extend([t for t in self.bpe(token).split(" ")])
else: else:
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT) # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
text = self.nlp(text_standardize(self.fix_text(text))) text = self.nlp(text_standardize(self.fix_text(text)))
for token in text: for token in text:
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) split_tokens.extend([t for t in self.bpe(token.text.lower()).split(" ")])
return split_tokens return split_tokens
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
...@@ -182,7 +186,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -182,7 +186,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """ """ Converts a sequence of tokens (string) in a single string. """
out_string = ''.join(tokens).replace('</w>', ' ').strip() out_string = "".join(tokens).replace("</w>", " ").strip()
return out_string return out_string
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):
...@@ -190,21 +194,23 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): ...@@ -190,21 +194,23 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])
with open(vocab_file, 'w', encoding='utf-8') as f: with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False)) f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0 index = 0
with open(merge_file, "w", encoding="utf-8") as writer: with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n') writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." logger.warning(
" Please check that the tokenizer is not corrupted!".format(merge_file)) "Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file)
)
index = token_index index = token_index
writer.write(' '.join(bpe_tokens) + u'\n') writer.write(" ".join(bpe_tokens) + "\n")
index += 1 index += 1
return vocab_file, merge_file return vocab_file, merge_file
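A hedged sketch of the `</w>` end-of-word convention used by this tokenizer; `tokenizer` is assumed to be a loaded `OpenAIGPTTokenizer`, and the exact sub-tokens depend on the `openai-gpt` vocabulary:

tokens = tokenizer.tokenize("Hello WORLD")           # inputs are lower-cased
print(tokens)                                        # e.g. ['hello</w>', 'world</w>']

# convert_tokens_to_string() simply joins the pieces and turns each </w>
# marker back into a space before stripping:
print(tokenizer.convert_tokens_to_string(tokens))    # hello world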
...@@ -13,18 +13,13 @@ ...@@ -13,18 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for RoBERTa.""" """Tokenization classes for RoBERTa."""
from __future__ import (absolute_import, division, print_function, from __future__ import absolute_import, division, print_function, unicode_literals
unicode_literals)
import sys
import json
import logging import logging
import os
import regex as re
from io import open
from .tokenization_gpt2 import GPT2Tokenizer from .tokenization_gpt2 import GPT2Tokenizer
try: try:
from functools import lru_cache from functools import lru_cache
except ImportError: except ImportError:
...@@ -33,41 +28,40 @@ except ImportError: ...@@ -33,41 +28,40 @@ except ImportError:
def lru_cache(): def lru_cache():
return lambda func: func return lambda func: func
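The fallback above makes `lru_cache` a no-op on Python 2, where `functools.lru_cache` does not exist; a small sketch of the effect:

@lru_cache()
def squared(x):
    return x * x

# On Python 3 results are memoised by functools.lru_cache; with the fallback
# the decorated function is returned unchanged, so only the caching differs.
assert squared(3) == 9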
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', "vocab_file": "vocab.json",
'merges_file': 'merges.txt', "merges_file": "merges.txt",
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json", "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-vocab.json",
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-vocab.json", "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
}, },
'merges_file': "merges_file": {
{ "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt", "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-merges.txt",
'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-merges.txt", "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
}, },
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'roberta-base': 512, "roberta-base": 512,
'roberta-large': 512, "roberta-large": 512,
'roberta-large-mnli': 512, "roberta-large-mnli": 512,
'distilroberta-base': 512, "distilroberta-base": 512,
'roberta-base-openai-detector': 512, "roberta-base-openai-detector": 512,
'roberta-large-openai-detector': 512, "roberta-large-openai-detector": 512,
} }
...@@ -80,16 +74,38 @@ class RobertaTokenizer(GPT2Tokenizer): ...@@ -80,16 +74,38 @@ class RobertaTokenizer(GPT2Tokenizer):
Otherwise, this tokenizer ``encode`` and ``decode`` method will not conserve Otherwise, this tokenizer ``encode`` and ``decode`` method will not conserve
the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"` the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>", def __init__(
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs): self,
super(RobertaTokenizer, self).__init__(vocab_file=vocab_file, merges_file=merges_file, errors=errors, vocab_file,
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, merges_file,
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, errors="replace",
mask_token=mask_token, **kwargs) bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs
):
super(RobertaTokenizer, self).__init__(
vocab_file=vocab_file,
merges_file=merges_file,
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
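As a worked note on the `max_len - 2` / `max_len - 4` bookkeeping above: a single RoBERTa sequence is wrapped as `<s> A </s>` (two special tokens) and a pair as `<s> A </s> </s> B </s>` (four), which is what the sketch below counts (the helper name is illustrative only):

def max_content_tokens(max_len, is_pair):
    """How many content tokens fit once RoBERTa's special tokens are added."""
    return max_len - (4 if is_pair else 2)

assert max_content_tokens(512, is_pair=False) == 510
assert max_content_tokens(512, is_pair=True) == 508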
...@@ -124,8 +140,10 @@ class RobertaTokenizer(GPT2Tokenizer): ...@@ -124,8 +140,10 @@ class RobertaTokenizer(GPT2Tokenizer):
""" """
if already_has_special_tokens: if already_has_special_tokens:
if token_ids_1 is not None: if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of " raise ValueError(
"ids is already formated with special tokens for the model.") "You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is None: if token_ids_1 is None:
......
...@@ -19,33 +19,34 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -19,33 +19,34 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import logging import logging
import os import os
import re import re
import six
from shutil import copyfile from shutil import copyfile
import six
from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SPIECE_UNDERLINE = u'▁' SPIECE_UNDERLINE = "▁"
#################################################### ####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__` # Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances # to file names for serializing Tokenizer instances
#################################################### ####################################################
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'} VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
#################################################### ####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__` # Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names. # to pretrained vocabulary URL for all the model shortcut names.
#################################################### ####################################################
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model", "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
} }
} }
...@@ -53,13 +54,14 @@ PRETRAINED_VOCAB_FILES_MAP = { ...@@ -53,13 +54,14 @@ PRETRAINED_VOCAB_FILES_MAP = {
# Mapping from model shortcut names to max length of inputs # Mapping from model shortcut names to max length of inputs
#################################################### ####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
't5-small': 512, "t5-small": 512,
't5-base': 512, "t5-base": 512,
't5-large': 512, "t5-large": 512,
't5-3b': 512, "t5-3b": 512,
't5-11b': 512, "t5-11b": 512,
} }
class T5Tokenizer(PreTrainedTokenizer): class T5Tokenizer(PreTrainedTokenizer):
""" """
SentencePiece based tokenizer. Peculiarities: SentencePiece based tokenizer. Peculiarities:
...@@ -71,28 +73,43 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -71,28 +73,43 @@ class T5Tokenizer(PreTrainedTokenizer):
(like in T5 preprocessing (like in T5 preprocessing
see: https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117) see: https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, eos_token="</s>", unk_token="<unk>", def __init__(
pad_token="<pad>", extra_ids=100, additional_special_tokens=None, **kwargs): self,
vocab_file,
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>",
extra_ids=100,
additional_special_tokens=None,
**kwargs
):
# Add extra_ids to the special token list # Add extra_ids to the special token list
if extra_ids > 0: if extra_ids > 0:
if additional_special_tokens is None: if additional_special_tokens is None:
additional_special_tokens = [] additional_special_tokens = []
additional_special_tokens.extend([u"<extra_id_{}>".format(i) for i in range(extra_ids)]) additional_special_tokens.extend(["<extra_id_{}>".format(i) for i in range(extra_ids)])
super(T5Tokenizer, self).__init__(eos_token=eos_token, unk_token=unk_token, super(T5Tokenizer, self).__init__(
pad_token=pad_token, additional_special_tokens=additional_special_tokens, eos_token=eos_token,
**kwargs) unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
try: try:
import sentencepiece as spm import sentencepiece as spm
except ImportError: except ImportError:
logger.warning("You need to install SentencePiece to use T5Tokenizer:" logger.warning(
"https://github.com/google/sentencepiece" "You need to install SentencePiece to use T5Tokenizer:"
"pip install sentencepiece") "https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.vocab_file = vocab_file self.vocab_file = vocab_file
self._extra_ids = extra_ids self._extra_ids = extra_ids
...@@ -114,8 +131,10 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -114,8 +131,10 @@ class T5Tokenizer(PreTrainedTokenizer):
try: try:
import sentencepiece as spm import sentencepiece as spm
except ImportError: except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" logger.warning(
"pip install sentencepiece") "You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.sp_model = spm.SentencePieceProcessor() self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file) self.sp_model.Load(self.vocab_file)
...@@ -132,7 +151,7 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -132,7 +151,7 @@ class T5Tokenizer(PreTrainedTokenizer):
ret_pieces = [] ret_pieces = []
for piece in pieces: for piece in pieces:
if isinstance(piece, str): if isinstance(piece, str):
piece = piece.decode('utf-8') piece = piece.decode("utf-8")
ret_pieces.append(piece) ret_pieces.append(piece)
pieces = ret_pieces pieces = ret_pieces
...@@ -140,9 +159,9 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -140,9 +159,9 @@ class T5Tokenizer(PreTrainedTokenizer):
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """ """ Converts a token (str/unicode) in an id using the vocab. """
if token.startswith(u"<extra_id_"): if token.startswith("<extra_id_"):
l = re.match(r'<extra_id_(\d+)>', token) match = re.match(r"<extra_id_(\d+)>", token)
num = int(l.group(1)) num = int(match.group(1))
return self.vocab_size - num - 1 return self.vocab_size - num - 1
return self.sp_model.piece_to_id(token) return self.sp_model.piece_to_id(token)
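The sentinel arithmetic above maps `<extra_id_0>` to the last id in the vocabulary and counts downwards; a small sketch with a hypothetical vocabulary size (the real value comes from the SentencePiece model plus `extra_ids`):

import re

VOCAB_SIZE = 32100  # hypothetical: e.g. 32000 SentencePiece pieces + 100 extra ids

def extra_token_to_id(token):
    num = int(re.match(r"<extra_id_(\d+)>", token).group(1))
    return VOCAB_SIZE - num - 1

assert extra_token_to_id("<extra_id_0>") == VOCAB_SIZE - 1     # last id
assert extra_token_to_id("<extra_id_99>") == VOCAB_SIZE - 100  # lowest sentinel id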
...@@ -151,9 +170,9 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -151,9 +170,9 @@ class T5Tokenizer(PreTrainedTokenizer):
if index < self.sp_model.get_piece_size(): if index < self.sp_model.get_piece_size():
token = self.sp_model.IdToPiece(index) token = self.sp_model.IdToPiece(index)
else: else:
token = u"<extra_id_{}>".format(self.vocab_size - 1 - index) token = "<extra_id_{}>".format(self.vocab_size - 1 - index)
if six.PY2 and return_unicode and isinstance(token, str): if six.PY2 and return_unicode and isinstance(token, str):
token = token.decode('utf-8') token = token.decode("utf-8")
return token return token
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
...@@ -168,7 +187,7 @@ class T5Tokenizer(PreTrainedTokenizer): ...@@ -168,7 +187,7 @@ class T5Tokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file) copyfile(self.vocab_file, out_vocab_file)
......