Commit 870b734b authored by thomwolf

added tokenizers serialization tests

parent 3e65f255
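The change itself is small: each tokenizer's save_vocabulary now returns the path(s) of the file(s) it wrote (and the GPT/GPT-2 versions refuse anything that is not a directory), and the new tests round-trip a tokenizer through save_vocabulary and from_pretrained. Below is a minimal sketch of that round trip for BertTokenizer, assuming pytorch_pretrained_bert is installed; the toy vocabulary and temporary paths are illustrative only and are not part of this commit.

import os
import tempfile

from pytorch_pretrained_bert import BertTokenizer

# Toy vocabulary so the sketch is self-contained (mirrors the unit test below).
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed",
                "wa", "un", "runn", "##ing", ","]
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".txt") as f:
    f.write("".join(token + "\n" for token in vocab_tokens))
    tmp_vocab = f.name

tokenizer = BertTokenizer(tmp_vocab)

# save_vocabulary now returns the file it wrote, so the result can be fed
# straight back into from_pretrained to rebuild an equivalent tokenizer.
save_dir = tempfile.mkdtemp()
saved_vocab_file = tokenizer.save_vocabulary(vocab_path=save_dir)
reloaded = tokenizer.from_pretrained(saved_vocab_file)

tokens = reloaded.tokenize(u"UNwant\u00E9d,running")
print(tokens)                                  # ['un', '##want', '##ed', ',', 'runn', '##ing']
print(reloaded.convert_tokens_to_ids(tokens))  # [7, 4, 5, 10, 8, 9]

os.remove(tmp_vocab)
os.remove(saved_vocab_file)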
@@ -146,6 +146,7 @@ class BertTokenizer(object):
                     index = token_index
                 writer.write(token + u'\n')
                 index += 1
+        return vocab_file
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
...
@@ -188,7 +188,10 @@ class GPT2Tokenizer(object):
         return word
 
     def save_vocabulary(self, vocab_path):
-        """Save the tokenizer vocabulary to a path."""
+        """Save the tokenizer vocabulary and merge files to a directory."""
+        if not os.path.isdir(vocab_path):
+            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+            return
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         merge_file = os.path.join(vocab_path, MERGES_NAME)
 
         json.dump(self.encoder, vocab_file)
@@ -202,6 +205,7 @@ class GPT2Tokenizer(object):
                     index = token_index
                 writer.write(bpe_tokens + u'\n')
                 index += 1
+        return vocab_file, merge_file
 
     def encode(self, text):
         bpe_tokens = []
...
@@ -263,7 +263,10 @@ class OpenAIGPTTokenizer(object):
         return out_string
 
     def save_vocabulary(self, vocab_path):
-        """Save the tokenizer vocabulary to a path."""
+        """Save the tokenizer vocabulary and merge files to a directory."""
+        if not os.path.isdir(vocab_path):
+            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+            return
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         merge_file = os.path.join(vocab_path, MERGES_NAME)
 
         json.dump(self.encoder, vocab_file)
@@ -277,3 +280,4 @@ class OpenAIGPTTokenizer(object):
                     index = token_index
                 writer.write(bpe_tokens + u'\n')
                 index += 1
+        return vocab_file, merge_file
@@ -148,6 +148,7 @@ class TransfoXLTokenizer(object):
         index = 0
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         torch.save(self.__dict__, vocab_file)
+        return vocab_file
 
     def build_vocab(self):
         if self.vocab_file:
...
@@ -52,5 +52,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
+        vocab_file, merges_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained("/tmp/")
+        os.remove(vocab_file)
+        os.remove(merges_file)
+
+        text = "lower"
+        bpe_tokens = ["low", "er</w>"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)
+
+        input_tokens = tokens + ["<unk>"]
+        input_bpe_tokens = [14, 15, 20]
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
 if __name__ == '__main__':
     unittest.main()
@@ -46,6 +46,17 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
 
+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()
...
@@ -18,9 +18,7 @@ import os
 import unittest
 from io import open
 
-from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
-                                                             _is_control, _is_punctuation,
-                                                             _is_whitespace)
+from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer
 
 class TransfoXLTokenizationTest(unittest.TestCase):
@@ -43,6 +41,17 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
 
+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
+
     def test_full_tokenizer_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=True)
@@ -58,33 +67,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
             tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])
 
-    def test_is_whitespace(self):
-        self.assertTrue(_is_whitespace(u" "))
-        self.assertTrue(_is_whitespace(u"\t"))
-        self.assertTrue(_is_whitespace(u"\r"))
-        self.assertTrue(_is_whitespace(u"\n"))
-        self.assertTrue(_is_whitespace(u"\u00A0"))
-
-        self.assertFalse(_is_whitespace(u"A"))
-        self.assertFalse(_is_whitespace(u"-"))
-
-    def test_is_control(self):
-        self.assertTrue(_is_control(u"\u0005"))
-
-        self.assertFalse(_is_control(u"A"))
-        self.assertFalse(_is_control(u" "))
-        self.assertFalse(_is_control(u"\t"))
-        self.assertFalse(_is_control(u"\r"))
-
-    def test_is_punctuation(self):
-        self.assertTrue(_is_punctuation(u"-"))
-        self.assertTrue(_is_punctuation(u"$"))
-        self.assertTrue(_is_punctuation(u"`"))
-        self.assertTrue(_is_punctuation(u"."))
-
-        self.assertFalse(_is_punctuation(u"A"))
-        self.assertFalse(_is_punctuation(u" "))
-
 if __name__ == '__main__':
     unittest.main()