Commit f2a337b3 authored by thomwolf

fix tokenization tests for gpt2 roberta

parent 7a99e4b1
@@ -52,14 +52,14 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):

     def get_input_output_texts(self):
         input_text = u"lower newer"
-        output_text = u" lower newer"
+        output_text = u"lower newer"
         return input_text, output_text

     def test_full_tokenizer(self):
         tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower newer"
         bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
-        tokens = tokenizer.tokenize(text)
+        tokens = tokenizer.tokenize(text, add_prefix_space=True)
         self.assertListEqual(tokens, bpe_tokens)

         input_tokens = tokens + [tokenizer.unk_token]
...
@@ -51,14 +51,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):

     def get_input_output_texts(self):
         input_text = u"lower newer"
-        output_text = u" lower newer"
+        output_text = u"lower newer"
         return input_text, output_text

     def test_full_tokenizer(self):
         tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower newer"
         bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
-        tokens = tokenizer.tokenize(text)
+        tokens = tokenizer.tokenize(text, add_prefix_space=True)
         self.assertListEqual(tokens, bpe_tokens)

         input_tokens = tokens + [tokenizer.unk_token]
...
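For context (not part of the commit): after this change the tests assume that the GPT-2 byte-level BPE tokenizer only marks a leading space on the first word when one is explicitly requested, so the expected "\u0120low" token now requires passing add_prefix_space=True to tokenize(), and the decode round-trip output loses its leading space; RoBERTa reuses the same byte-level BPE, hence the identical test change. A rough sketch of that behavior, assuming the pytorch_transformers package of this era and a downloadable "gpt2" checkpoint (the tests themselves use a tiny hand-built vocab/merges fixture, so their exact token lists differ):

    from pytorch_transformers import GPT2Tokenizer

    # Pretrained byte-level BPE tokenizer; "\u0120" ("Ġ") encodes a leading space.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    text = "lower newer"

    # Default: nothing is prepended, so the first token carries no "\u0120" marker.
    tokens_default = tokenizer.tokenize(text)

    # Tokenize as if the text were preceded by a space: the first token now starts
    # with "\u0120", matching the "\u0120low"-style expectation in the tests above.
    tokens_prefixed = tokenizer.tokenize(text, add_prefix_space=True)

    print(tokens_default)
    print(tokens_prefixed)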