Commit abe734ca authored by thomwolf's avatar thomwolf
Browse files

fix GPT-2 and RoBERTa tests to be clean now

parent 0f5a7994
...@@ -31,17 +31,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -31,17 +31,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"lo", "low", "er", "\u0120", "\u0120l", "\u0120n",
"low", "lowest", "newer", "wider", "<unk>"] "\u0120lo", "\u0120low", "er",
"\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab)))) vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r", ""] merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"} self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(self.vocab_file, "w") as fp: with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens)) fp.write(json.dumps(vocab_tokens))
with open(self.merges_file, "w") as fp: with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges)) fp.write("\n".join(merges))
def get_tokenizer(self): def get_tokenizer(self):
...@@ -49,18 +50,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -49,18 +50,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"lower newer" input_text = u"lower newer"
output_text = u"lower<unk>newer" output_text = u" lower newer"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower" text = "lower"
bpe_tokens = ["low", "er"] bpe_tokens = ["\u0120low", "er"]
tokens = tokenizer.tokenize(text) tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens) self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token] input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [13, 12, 17] input_bpe_tokens = [14, 15, 19]
self.assertListEqual( self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
......
...@@ -30,17 +30,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -30,17 +30,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"lo", "low", "er", "\u0120", "\u0120l", "\u0120n",
"low", "lowest", "newer", "wider", "<unk>"] "\u0120lo", "\u0120low", "er",
"\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab)))) vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r", ""] merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"} self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(self.vocab_file, "w") as fp: with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens)) fp.write(json.dumps(vocab_tokens))
with open(self.merges_file, "w") as fp: with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges)) fp.write("\n".join(merges))
def get_tokenizer(self): def get_tokenizer(self):
...@@ -48,18 +49,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -48,18 +49,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"lower newer" input_text = u"lower newer"
output_text = u"lower<unk>newer" output_text = u" lower newer"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower" text = "lower"
bpe_tokens = ["low", "er"] bpe_tokens = ["\u0120low", "er"]
tokens = tokenizer.tokenize(text) tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens) self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token] input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [13, 12, 17] input_bpe_tokens = [14, 15, 19]
self.assertListEqual( self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
......
...@@ -129,7 +129,7 @@ class CommonTestCases: ...@@ -129,7 +129,7 @@ class CommonTestCases:
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
self.assertGreater(tokens[-2], tokens[-3]) self.assertGreater(tokens[-2], tokens[-3])
self.assertEqual(tokens[0], tokenizer.eos_token_id) self.assertEqual(tokens[0], tokenizer.eos_token_id)
self.assertEqual(tokens[-2], tokenizer.eos_token_id) self.assertEqual(tokens[-2], tokenizer.pad_token_id)
def test_required_methods_tokenizer(self): def test_required_methods_tokenizer(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment