Unverified Commit 2818e505 authored by Anthony MOI

Add tests for fast tokenizers

parent 31c56f2e
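
The commit adds parity tests comparing the existing Python tokenizers with their Rust-backed "fast" counterparts (BertTokenizerFast, GPT2TokenizerFast). A minimal sketch of the behaviour these tests check, assuming a recent transformers release and the public bert-base-uncased checkpoint (neither is part of this commit):

# Illustrative sketch only, not code from this commit.
from transformers import BertTokenizer, BertTokenizerFast

# Load the pure-Python and the Rust-backed ("fast") tokenizer from the same checkpoint.
python_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
rust_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

sequence = "UNwant\u00E9d,running"

# The token strings produced by the two implementations should be identical.
assert python_tokenizer.tokenize(sequence) == rust_tokenizer.tokenize(sequence)

# The token ids should also match, both without and with special tokens.
assert python_tokenizer.encode(sequence, add_special_tokens=False) == rust_tokenizer.encode(
    sequence, add_special_tokens=False
)
assert python_tokenizer.encode(sequence) == rust_tokenizer.encode(sequence)
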
...@@ -21,6 +21,7 @@ from transformers.tokenization_bert import (
    VOCAB_FILES_NAMES,
    BasicTokenizer,
    BertTokenizer,
    BertTokenizerFast,
    WordpieceTokenizer,
    _is_control,
    _is_punctuation,
...@@ -34,6 +35,7 @@ from .utils import slow

class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertTokenizer
    test_rust_tokenizer = True

    def setUp(self):
        super(BertTokenizationTest, self).setUp()
...@@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    def get_tokenizer(self, **kwargs):
        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = "UNwant\u00E9d,running"
        output_text = "unwanted, running"
...@@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False)

        sequence = u"UNwant\u00E9d,running"

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

    def test_chinese(self):
        tokenizer = BasicTokenizer()
...
...@@ -23,6 +23,7 @@ import tempfile

class TokenizerTesterMixin:

    tokenizer_class = None
    test_rust_tokenizer = False

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
...@@ -33,6 +34,9 @@ class TokenizerTesterMixin:
    def get_tokenizer(self, **kwargs):
        raise NotImplementedError

    def get_rust_tokenizer(self, **kwargs):
        raise NotImplementedError

    def get_input_output_texts(self):
        raise NotImplementedError
...
...@@ -18,7 +18,7 @@ import json
import os
import unittest

from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer, GPT2TokenizerFast

from .test_tokenization_common import TokenizerTesterMixin
...@@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin

class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = GPT2Tokenizer
    test_rust_tokenizer = True

    def setUp(self):
        super(GPT2TokenizationTest, self).setUp()
...@@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        kwargs.update(self.special_tokens_map)
        return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        kwargs.update(self.special_tokens_map)
        return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = "lower newer"
        output_text = "lower newer"
...@@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        input_tokens = tokens + [tokenizer.unk_token]
        input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True)

        sequence = u"lower newer"

        # Testing tokenization
        tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        # Testing conversion to ids without special tokens
        ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

        # Testing conversion to ids with special tokens
        rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
        ids = tokenizer.encode(sequence, add_prefix_space=True)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

        # Testing the unknown token
        input_tokens = tokens + [rust_tokenizer.unk_token]
        input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
        self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)