Unverified commit 3d78e226, authored by Thomas Wolf, committed by GitHub

Merge pull request #489 from huggingface/tokenization_serialization

Better serialization for Tokenizer and Configuration classes; also fixes #466
Parents: 64b6ef4d, 3571187e
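The test changes below all exercise the same new save/reload roundtrip on the tokenizer classes. For orientation, a minimal sketch of that pattern, assuming the pytorch_pretrained_bert package is installed; the model name and the /tmp/ target directory are illustrative, not part of this PR:

    # Minimal sketch of the serialization roundtrip the new tests cover.
    # "bert-base-uncased" and /tmp/ are illustrative choices.
    from pytorch_pretrained_bert import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # New in this PR: write the vocabulary to a directory and get the file path back.
    vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")

    # Reload from the saved file; the two tokenizers should behave identically.
    reloaded = BertTokenizer.from_pretrained(vocab_file)
    assert tokenizer.tokenize(u"UNwant\u00E9d,running") == reloaded.tokenize(u"UNwant\u00E9d,running")

save_vocabulary returns the paths of the files it wrote, so the same tokenizer can be reconstructed later with from_pretrained.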
@@ -52,5 +52,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
+        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained("/tmp/")
+        os.remove(vocab_file)
+        os.remove(merges_file)
+
+        text = "lower"
+        bpe_tokens = ["low", "er</w>"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)
+
+        input_tokens = tokens + ["<unk>"]
+        input_bpe_tokens = [14, 15, 20]
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
 if __name__ == '__main__':
     unittest.main()
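Note that the GPT BPE tokenizer serializes three artifacts, which is why the test above unpacks a 3-tuple. A hedged sketch of the same roundtrip for this class, with the model name and paths again illustrative:

    # Sketch only: mirrors the test above; "openai-gpt" and /tmp/ are illustrative.
    from pytorch_pretrained_bert import OpenAIGPTTokenizer

    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")

    # save_vocabulary returns the paths of all three serialized files:
    # the vocab, the BPE merges, and the special-tokens list.
    vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")

    # Reloading from the directory picks the three files back up.
    reloaded = OpenAIGPTTokenizer.from_pretrained("/tmp/")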
@@ -46,6 +46,17 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
 
+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()
...
@@ -18,9 +18,7 @@ import os
 import unittest
 from io import open
 
-from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
-                                                             _is_control, _is_punctuation,
-                                                             _is_whitespace)
+from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer
 
 class TransfoXLTokenizationTest(unittest.TestCase):
@@ -37,54 +35,37 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         tokenizer.build_vocab()
         os.remove(vocab_file)
 
-        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
 
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
 
+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
+        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
+
     def test_full_tokenizer_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=True)
 
         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["hello", "!", "how", "are", "you", "?"])
-        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
 
     def test_full_tokenizer_no_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=False)
 
         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])
 
-    def test_is_whitespace(self):
-        self.assertTrue(_is_whitespace(u" "))
-        self.assertTrue(_is_whitespace(u"\t"))
-        self.assertTrue(_is_whitespace(u"\r"))
-        self.assertTrue(_is_whitespace(u"\n"))
-        self.assertTrue(_is_whitespace(u"\u00A0"))
-        self.assertFalse(_is_whitespace(u"A"))
-        self.assertFalse(_is_whitespace(u"-"))
-
-    def test_is_control(self):
-        self.assertTrue(_is_control(u"\u0005"))
-        self.assertFalse(_is_control(u"A"))
-        self.assertFalse(_is_control(u" "))
-        self.assertFalse(_is_control(u"\t"))
-        self.assertFalse(_is_control(u"\r"))
-
-    def test_is_punctuation(self):
-        self.assertTrue(_is_punctuation(u"-"))
-        self.assertTrue(_is_punctuation(u"$"))
-        self.assertTrue(_is_punctuation(u"`"))
-        self.assertTrue(_is_punctuation(u"."))
-        self.assertFalse(_is_punctuation(u"A"))
-        self.assertFalse(_is_punctuation(u" "))
-
 if __name__ == '__main__':
     unittest.main()
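For Transformer-XL the roundtrip is single-file: save_vocabulary returns one path and, as in the updated test above, from_pretrained is pointed at that file rather than at the directory. A sketch under the same assumptions as before ("transfo-xl-wt103" and /tmp/ are illustrative):

    # Sketch only: mirrors the updated TransfoXL test.
    from pytorch_pretrained_bert import TransfoXLTokenizer

    tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
    vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")

    # The saved vocabulary file itself is passed to from_pretrained here.
    reloaded = TransfoXLTokenizer.from_pretrained(vocab_file)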