Commit abe734ca authored by thomwolf

fix GPT-2 and RoBERTa tests to be clean now

parent 0f5a7994
@@ -31,17 +31,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
-                 "lo", "low", "er",
-                 "low", "lowest", "newer", "wider", "<unk>"]
+                 "\u0120", "\u0120l", "\u0120n",
+                 "\u0120lo", "\u0120low", "er",
+                 "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
-        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
+        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
         self.special_tokens_map = {"unk_token": "<unk>"}
 
         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-        with open(self.vocab_file, "w") as fp:
+        with open(self.vocab_file, "w", encoding="utf-8") as fp:
             fp.write(json.dumps(vocab_tokens))
-        with open(self.merges_file, "w") as fp:
+        with open(self.merges_file, "w", encoding="utf-8") as fp:
             fp.write("\n".join(merges))
 
     def get_tokenizer(self):
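A note on the new fixture data: "\u0120" is "Ġ", the symbol GPT-2's byte-level BPE uses for a leading space. A minimal sketch (not the library code) of where that code point comes from:

    # GPT-2's byte-to-unicode table maps each non-printable byte to
    # 256 + n, where n counts the excluded bytes so far; bytes 0..32 are
    # all excluded, so for the space byte the shift is just ord(" ") + 256.
    marker = chr(ord(" ") + 256)
    assert marker == "\u0120"   # rendered as "Ġ"

So " lower" pre-tokenizes to the symbols Ġ l o w e r, and the new merges "\u0120 l" -> "\u0120l", "\u0120l o" -> "\u0120lo", "\u0120lo w" -> "\u0120low", together with "e r" -> "er", collapse it to ["\u0120low", "er"].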
@@ -49,18 +50,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
     def get_input_output_texts(self):
         input_text = u"lower newer"
-        output_text = u"lower<unk>newer"
+        output_text = u" lower newer"
         return input_text, output_text
 
     def test_full_tokenizer(self):
         tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower"
-        bpe_tokens = ["low", "er"]
+        bpe_tokens = ["\u0120low", "er"]
         tokens = tokenizer.tokenize(text)
         self.assertListEqual(tokens, bpe_tokens)
 
         input_tokens = tokens + [tokenizer.unk_token]
-        input_bpe_tokens = [13, 12, 17]
+        input_bpe_tokens = [14, 15, 19]
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
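The new expected ids follow directly from the enlarged toy vocab; a quick runnable check mirroring the fixture above:

    vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
             "\u0120", "\u0120l", "\u0120n",
             "\u0120lo", "\u0120low", "er",
             "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
    vocab_tokens = dict(zip(vocab, range(len(vocab))))
    print([vocab_tokens[t] for t in ["\u0120low", "er", "<unk>"]])  # [14, 15, 19]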
@@ -30,17 +30,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
-                 "lo", "low", "er",
-                 "low", "lowest", "newer", "wider", "<unk>"]
+                 "\u0120", "\u0120l", "\u0120n",
+                 "\u0120lo", "\u0120low", "er",
+                 "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
-        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
+        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
         self.special_tokens_map = {"unk_token": "<unk>"}
 
         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-        with open(self.vocab_file, "w") as fp:
+        with open(self.vocab_file, "w", encoding="utf-8") as fp:
             fp.write(json.dumps(vocab_tokens))
-        with open(self.merges_file, "w") as fp:
+        with open(self.merges_file, "w", encoding="utf-8") as fp:
             fp.write("\n".join(merges))
 
     def get_tokenizer(self):
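The encoding="utf-8" arguments are what make these fixtures portable. json.dumps escapes non-ASCII by default, so the vocab file was already safe, but the merges file writes raw "\u0120" characters, which a platform-default codec such as cp1252 cannot encode. A minimal sketch of the failure mode, using a hypothetical merges.txt path:

    merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
    with open("merges.txt", "w") as fp:               # may raise UnicodeEncodeError
        fp.write("\n".join(merges))                   # on a non-UTF-8 default locale
    with open("merges.txt", "w", encoding="utf-8") as fp:
        fp.write("\n".join(merges))                   # always safe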
@@ -48,18 +49,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
     def get_input_output_texts(self):
         input_text = u"lower newer"
-        output_text = u"lower<unk>newer"
+        output_text = u" lower newer"
         return input_text, output_text
 
     def test_full_tokenizer(self):
         tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower"
-        bpe_tokens = ["low", "er"]
+        bpe_tokens = ["\u0120low", "er"]
         tokens = tokenizer.tokenize(text)
         self.assertListEqual(tokens, bpe_tokens)
 
         input_tokens = tokens + [tokenizer.unk_token]
-        input_bpe_tokens = [13, 12, 17]
+        input_bpe_tokens = [14, 15, 19]
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
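The leading space in the new output_text is the visible effect of the Ġ marker: decoding maps "\u0120low" back to " low". A hedged sketch of the round trip, assuming the common tester goes through tokenize, convert_tokens_to_ids, and decode, with vocab_file and merges_file standing in for the fixture paths written in setUp above:

    tokenizer = RobertaTokenizer(vocab_file, merges_file, unk_token="<unk>")
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(u"lower newer"))
    assert tokenizer.decode(ids) == u" lower newer"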
@@ -111,7 +111,7 @@ class CommonTestCases:
             self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
 
             new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
-                        'pad_token': "<<<<<|||>|>>>>|>"}
+                          'pad_token': "<<<<<|||>|>>>>|>"}
             added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
             vocab_size_3 = tokenizer.vocab_size
             all_size_3 = len(tokenizer)
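For context, the assertions around this hunk lean on the split between the base vocab and added tokens: vocab_size counts only the base vocab, while len(tokenizer) also counts everything added on top of it. A hedged sketch of that bookkeeping, reusing tokenizer and new_toks_2 from the test above:

    before = len(tokenizer)
    added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
    assert added_toks_2 == 2                        # eos_token and pad_token
    assert len(tokenizer) == before + added_toks_2  # added tokens extend len()
    assert tokenizer.vocab_size < len(tokenizer)    # but never vocab_size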
@@ -129,7 +129,7 @@ class CommonTestCases:
             self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
             self.assertGreater(tokens[-2], tokens[-3])
             self.assertEqual(tokens[0], tokenizer.eos_token_id)
-            self.assertEqual(tokens[-2], tokenizer.eos_token_id)
+            self.assertEqual(tokens[-2], tokenizer.pad_token_id)
 
         def test_required_methods_tokenizer(self):
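The one-line fix checks the second-to-last token against the right id: per the surrounding assertions the encoded sequence opens with the new eos token and carries the new pad token just before its last position, and since eos and pad were registered as two distinct added tokens their ids differ. A quick consistency check, in the context of the test above:

    assert tokenizer.eos_token_id != tokenizer.pad_token_id
    assert tokens[0] == tokenizer.eos_token_id     # eos opens the sequence
    assert tokens[-2] == tokenizer.pad_token_id    # pad sits before the last token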