Commit 75d5f98f authored by LysandreJik

Roberta tokenization + fixed tests (py3 + py2).

parent 14e970c2
@@ -157,42 +157,6 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
             inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
             return config, inputs_dict
 
-    def test_inference_masked_lm(self):
-        model = RobertaForMaskedLM.from_pretrained('roberta-base')
-
-        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
-        output = model(input_ids)[0]
-        expected_shape = torch.Size((1, 11, 50265))
-        self.assertEqual(
-            output.shape,
-            expected_shape
-        )
-        # compare the actual values for a slice.
-        expected_slice = torch.Tensor(
-            [[[33.8843, -4.3107, 22.7779],
-              [4.6533, -2.8099, 13.6252],
-              [1.8222, -3.6898, 8.8600]]]
-        )
-        self.assertTrue(
-            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
-        )
-
-    # @pytest.mark.slow
-    def test_inference_no_head(self):
-        model = RobertaModel.from_pretrained('roberta-base')
-
-        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
-        output = model(input_ids)[0]
-        # compare the actual values for a slice.
-        expected_slice = torch.Tensor(
-            [[[-0.0231, 0.0782, 0.0074],
-              [-0.1854, 0.0539, -0.0174],
-              [0.0548, 0.0799, 0.1687]]]
-        )
-        self.assertTrue(
-            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
-        )
-
     def setUp(self):
         self.model_tester = RobertaModelTest.RobertaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=RobertaConfig, hidden_size=37)
@@ -220,7 +184,7 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
 class RobertaModelIntegrationTest(unittest.TestCase):
 
-    # @pytest.mark.slow
+    @pytest.mark.slow
     def test_inference_masked_lm(self):
         model = RobertaForMaskedLM.from_pretrained('roberta-base')
@@ -241,7 +205,7 @@ class RobertaModelIntegrationTest(unittest.TestCase):
             torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
         )
 
-    # @pytest.mark.slow
+    @pytest.mark.slow
     def test_inference_no_head(self):
         model = RobertaModel.from_pretrained('roberta-base')
...
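The two inference tests removed above were duplicates of the checks in RobertaModelIntegrationTest, which this commit re-enables by uncommenting @pytest.mark.slow. For context, those integration checks follow the usual pattern of running the pretrained checkpoint on a fixed input and comparing a slice of the output against hard-coded reference values. A minimal sketch of that pattern (illustrative only, assuming torch is installed, RobertaForMaskedLM is importable from pytorch_transformers.modeling_roberta, and the roberta-base weights can be downloaded):

import torch
from pytorch_transformers.modeling_roberta import RobertaForMaskedLM

# Sketch of the masked-LM integration check above (not an exact copy of the test).
model = RobertaForMaskedLM.from_pretrained('roberta-base')
model.eval()
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
with torch.no_grad():
    output = model(input_ids)[0]            # logits: (batch, seq_len, vocab_size)
assert output.shape == torch.Size((1, 11, 50265))
# The test then compares output[:, :3, :3] against reference values with torch.allclose(..., atol=1e-3).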
@@ -18,8 +18,7 @@ import os
 import json
 import unittest
 
-from pytorch_transformers.tokenization_roberta import RobertaTokenizer, DICT_FILES_NAMES
-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
+from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
 
 from .tokenization_tests_commons import CommonTestCases
@@ -45,8 +44,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
             fp.write("\n".join(merges))
 
     def get_tokenizer(self):
-        bpe_tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
-        return RobertaTokenizer.from_pretrained("roberta-base", bpe_tokenizer=bpe_tokenizer)
+        return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
 
     def get_input_output_texts(self):
         input_text = u"lower newer"
@@ -54,15 +52,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         return input_text, output_text
 
     def test_full_tokenizer(self):
-        tokenizer = self.get_tokenizer()
+        tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower"
         bpe_tokens = ["low", "er"]
         tokens = tokenizer.tokenize(text)
         self.assertListEqual(tokens, bpe_tokens)
 
         input_tokens = tokens + [tokenizer.unk_token]
-        input_bpe_tokens = [0, 4, 12, 176, 2]
-        tokenizer.convert_tokens_to_ids(input_tokens)
+        input_bpe_tokens = [13, 12, 17]
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
...
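With GPT2Tokenizer gone from the fixture, the test now builds RobertaTokenizer straight from a toy vocab.json / merges.txt pair. A minimal standalone sketch of that setup (the toy vocabulary and merge list below are hypothetical, not the exact fixture used in the test suite):

import json
import os
import tempfile

from pytorch_transformers.tokenization_roberta import RobertaTokenizer

# Hypothetical byte-level BPE fixture: single characters plus the merges needed for "lower".
vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3,
         "l": 4, "o": 5, "w": 6, "e": 7, "r": 8, "lo": 9, "low": 10, "er": 11}
# The trailing empty entry matters: the tokenizer drops the first and last line of merges.txt.
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]

tmpdir = tempfile.mkdtemp()
vocab_file = os.path.join(tmpdir, "vocab.json")
merges_file = os.path.join(tmpdir, "merges.txt")
with open(vocab_file, "w", encoding="utf-8") as fp:
    fp.write(json.dumps(vocab))
with open(merges_file, "w", encoding="utf-8") as fp:
    fp.write("\n".join(merges))

tokenizer = RobertaTokenizer(vocab_file, merges_file)
print(tokenizer.tokenize("lower"))                              # ["low", "er"]
print(tokenizer.convert_tokens_to_ids(["low", "er", "<unk>"]))  # [10, 11, 3]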
@@ -12,229 +12,182 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tokenization classes for RoBERTa."""
+"""Tokenization classes for OpenAI GPT."""
 from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
 
+import sys
 import json
 import logging
-import re
-from io import open
-import six
 import os
+import regex as re
+from io import open
 
+from .tokenization_gpt2 import bytes_to_unicode, get_pairs
 from .tokenization_utils import PreTrainedTokenizer
-from .tokenization_gpt2 import GPT2Tokenizer
+
+try:
+    from functools import lru_cache
+except ImportError:
+    # Just a dummy decorator to get the checks to run on python2
+    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    def lru_cache():
+        return lambda func: func
 
 logger = logging.getLogger(__name__)
 
-DICT_FILES_NAMES = {
-    'dict_file': 'dict.txt',
+VOCAB_FILES_NAMES = {
+    'vocab_file': 'vocab.json',
+    'merges_file': 'merges.txt',
 }
 
-PRETRAINED_DICT_FILES_MAP = {
-    'dict_file':
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
     {
-        'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
-        'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
-        'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
+        'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
+        'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
+        'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
     },
+    'merges_file':
+    {
+        'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
+        'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
+        'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
+    },
 }
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    'roberta-base': 512,
-    'roberta-large': 512,
-    'roberta-large-mnli': 512,
+    'roberta-base': 1024,
+    'roberta-large': 1024,
+    'roberta-large-mnli': 1024,
 }
 
-SPACE_NORMALIZER = re.compile(r"\s+")
-
-
-def tokenize_line(line):
-    line = SPACE_NORMALIZER.sub(" ", line)
-    line = line.strip()
-    return line.split()
-
-
-class Dictionary(object):
-    """
-    A mapping from symbols to consecutive integers
-    From Facebook's fairseq.
-    """
-    def __init__(
-        self,
-        pad='<pad>',
-        eos='</s>',
-        unk='<unk>',
-        bos='<s>',
-        extra_special_symbols=None,
-    ):
-        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
-        self.symbols = []
-        self.count = []
-        self.indices = {}
-        self.bos_index = self.add_symbol(bos)
-        self.pad_index = self.add_symbol(pad)
-        self.eos_index = self.add_symbol(eos)
-        self.unk_index = self.add_symbol(unk)
-        if extra_special_symbols:
-            for s in extra_special_symbols:
-                self.add_symbol(s)
-        self.nspecial = len(self.symbols)
-
-    def __getitem__(self, idx):
-        if idx < len(self.symbols):
-            return self.symbols[idx]
-        return self.unk_word
-
-    def index(self, sym):
-        """Returns the index of the specified symbol"""
-        assert isinstance(sym, str)
-        if sym in self.indices:
-            return self.indices[sym]
-        return self.unk_index
-
-    def add_symbol(self, word, n=1):
-        """Adds a word to the dictionary"""
-        if word in self.indices:
-            idx = self.indices[word]
-            self.count[idx] = self.count[idx] + n
-            return idx
-        else:
-            idx = len(self.symbols)
-            self.indices[word] = idx
-            self.symbols.append(word)
-            self.count.append(n)
-            return idx
-
-    @classmethod
-    def load(cls, f, ignore_utf_errors=False):
-        """Loads the dictionary from a text file with the format:
-        ```
-        <symbol0> <count0>
-        <symbol1> <count1>
-        ...
-        ```
-        """
-        d = cls()
-        d.add_from_file(f, ignore_utf_errors)
-        return d
-
-    def add_from_file(self, f, ignore_utf_errors=False):
-        """
-        Loads a pre-existing dictionary from a text file and adds its symbols
-        to this instance.
-        """
-        if isinstance(f, six.string_types):
-            try:
-                if not ignore_utf_errors:
-                    with open(f, 'r', encoding='utf-8') as fd:
-                        self.add_from_file(fd)
-                else:
-                    with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
-                        self.add_from_file(fd)
-            except FileNotFoundError as fnfe:
-                raise fnfe
-            except UnicodeError:
-                raise Exception("Incorrect encoding detected in {}, please "
-                                "rebuild the dataset".format(f))
-            return
-
-        lines = f.read().splitlines()
-        for line in lines:
-            idx = line.rfind(' ')
-            if idx == -1:
-                raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
-            word = line[:idx]
-            count = int(line[idx + 1:])
-            self.indices[word] = len(self.symbols)
-            self.symbols.append(word)
-            self.count.append(count)
-
-    def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
-                    consumer=None, append_eos=True, reverse_order=False):
-        words = line_tokenizer(line)
-        if reverse_order:
-            words = list(reversed(words))
-        nwords = len(words)
-        ids = [0] * (nwords + 1 if append_eos else nwords)
-
-        for i, word in enumerate(words):
-            if add_if_not_exist:
-                idx = self.add_symbol(word)
-            else:
-                idx = self.index(word)
-            if consumer is not None:
-                consumer(word, idx)
-            ids[i] = idx
-        if append_eos:
-            ids[nwords] = self.eos_index
-        return ids
 
 class RobertaTokenizer(PreTrainedTokenizer):
     """
-    RoBERTa tokenizer. Peculiarities:
-        - GPT-2 tokenizer with a different integer mapping on top.
+    GPT-2 BPE tokenizer. Peculiarities:
+        - Byte-level BPE
     """
-    vocab_files_names = DICT_FILES_NAMES
-    pretrained_vocab_files_map = PRETRAINED_DICT_FILES_MAP
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
-    def __init__(self, dict_file, bpe_tokenizer=None, bos_token="<s>", eos_token="</s>", sep_token="</s>", cls_token="<s>",
-                 unk_token="<unk>", **kwargs):
-        super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token,
-                                               unk_token=unk_token, **kwargs)
-        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if bpe_tokenizer is None else bpe_tokenizer
-        self.dictionary = Dictionary.load(dict_file)
+    def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>",
+                 cls_token="<s>", unk_token="<unk>", **kwargs):
+        super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
+                                               sep_token=sep_token, cls_token=cls_token, **kwargs)
+        self.encoder = json.load(open(vocab_file, encoding="utf-8"))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
+        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+
+        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
 
     @property
     def vocab_size(self):
-        return len(self.dictionary.indices)
+        return len(self.encoder)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
 
     def _tokenize(self, text):
-        """ Use GPT-2 Tokenizer """
-        return self.gpt2_tokenizer._tokenize(text)
+        """ Tokenize a string. """
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            if sys.version_info[0] == 2:
+                token = ''.join(self.byte_encoder[ord(b)] for b in token)
+            else:
+                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
 
     def _convert_token_to_id(self, token):
-        if self.dictionary.index(token) != 3:
-            return self.dictionary.index(token)
-        return self.dictionary.index(str(self.gpt2_tokenizer.convert_tokens_to_ids(token)))
+        """ Converts a token (str/unicode) in an id using the vocab. """
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
-        symbol = self.dictionary[index]
-        try:
-            idx = int(symbol)
-            return self.gpt2_tokenizer._convert_id_to_token(idx)
-        except ValueError:
-            return symbol
+        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        return self.decoder.get(index)
 
     def convert_tokens_to_string(self, tokens):
-        return self.gpt2_tokenizer.convert_tokens_to_string(tokens)
+        """ Converts a sequence of tokens (string) in a single string. """
+        text = ''.join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        return text
 
-    def convert_tokens_to_ids(self, tokens, no_sep_cls_tokens=False):
-        cls = [self._convert_token_to_id(self.cls_token)]
-        tokens = super().convert_tokens_to_ids(tokens)
-        sep = [self._convert_token_to_id(self.sep_token)]
-        return (cls + tokens + sep) if (isinstance(tokens, list) and not no_sep_cls_tokens) else tokens
+    def add_special_tokens_single_sentence(self, token_ids):
+        return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]
 
-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
-        return super().convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)[1:-1]
+    def add_special_tokens_sentences_pair(self, *token_ids):
+        sep = [self._convert_token_to_id(self.sep_token)]
+        cls = [self._convert_token_to_id(self.cls_token)]
+        return cls + token_ids[0] + sep + sep + token_ids[1] + sep
 
     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
         if not os.path.isdir(save_directory):
             logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
-        dict_file = os.path.join(save_directory, DICT_FILES_NAMES['dict_file'])
+        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
+        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
 
-        with open(dict_file, 'w', encoding='utf-8') as f:
-            for i in range(self.dictionary.nspecial, len(self.dictionary.count)):
-                f.write(f"{list(self.dictionary.indices.keys())[i]} {self.dictionary.count[i]}\n")
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
 
-        vocab_files = self.gpt2_tokenizer.save_pretrained(save_directory)
-
-        return vocab_files + (dict_file,)
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(' '.join(bpe_tokens) + u'\n')
+                index += 1
+
+        return vocab_file, merge_file
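Taken together, the rewrite makes RobertaTokenizer a self-contained byte-level BPE tokenizer driven by a vocab.json / merges.txt pair, instead of a GPT-2 tokenizer wrapped in a fairseq Dictionary. A rough usage sketch of the new API, assuming the roberta-base vocabulary files listed above can be downloaded:

from pytorch_transformers.tokenization_roberta import RobertaTokenizer

# Sketch only: fetch vocab.json / merges.txt for roberta-base and round-trip a sentence.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

tokens = tokenizer.tokenize("Hello world!")          # byte-level BPE pieces
ids = tokenizer.convert_tokens_to_ids(tokens)

# Wrap with <s> ... </s> the way the model expects for a single sentence.
ids_with_special_tokens = tokenizer.add_special_tokens_single_sentence(ids)

# Decode back to text via the byte decoder.
text = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(ids))
assert text == "Hello world!"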