Commit 00204f2b authored by Aymeric Augustin

Replace CommonTestCases for tokenizers with a mixin.

This is the same change as for (TF)CommonTestCases for modeling.
parent a3c5883f
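For readers skimming the diff below, this is the pattern being applied, shown as a minimal self-contained sketch (the CommonTestCases and TokenizerTesterMixin names come from the diff; the stand-in test method and the ExampleTokenizationTest class are illustrative only, not code from this commit):

import unittest

# Before: the shared tokenizer tests sat on a TestCase nested inside a plain
# container class, so the generic tests were never collected on their own.
class CommonTestCases:
    class CommonTokenizerTester(unittest.TestCase):
        tokenizer_class = None  # each concrete tokenizer test overrides this

# After: the shared tests live on a plain mixin. The mixin is not a TestCase,
# so the test runner never runs it on its own; each concrete test class
# combines it with unittest.TestCase and the shared tests then run against
# that class's tokenizer_class.
class TokenizerTesterMixin:
    tokenizer_class = None  # each concrete tokenizer test overrides this

    def test_tokenizer_class_is_set(self):
        # Stand-in for the real shared tests defined in test_tokenization_common.
        self.assertIsNotNone(self.tokenizer_class)

class ExampleTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = object  # a real test sets e.g. BertTokenizer here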
@@ -15,14 +15,15 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
+import unittest
from io import open
from transformers.tokenization_bert import VOCAB_FILES_NAMES, XxxTokenizer
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
-class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class XxxTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = XxxTokenizer
...
@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function
import json
import os
from .test_tokenization_common import TemporaryDirectory
class ConfigTester(object):
...
@@ -20,7 +20,7 @@ import unittest
from transformers.modelcard import ModelCard
from .test_tokenization_common import TemporaryDirectory
class ModelCardTester(unittest.TestCase):
...
@@ -19,7 +19,7 @@ import unittest
from transformers import is_torch_available
from .test_tokenization_common import TemporaryDirectory
from .utils import require_torch
...
@@ -15,16 +15,17 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
+import unittest
from transformers.tokenization_albert import AlbertTokenizer
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/spiece.model")
-class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = AlbertTokenizer
...
@@ -15,6 +15,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
+import unittest
from io import open
from transformers.tokenization_bert import (
@@ -27,11 +28,11 @@ from transformers.tokenization_bert import (
    _is_whitespace,
)
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
-class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = BertTokenizer
...
@@ -15,6 +15,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
+import unittest
from io import open
from transformers.tokenization_bert import WordpieceTokenizer
@@ -25,12 +26,12 @@ from transformers.tokenization_bert_japanese import (
    MecabTokenizer,
)
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
from .utils import custom_tokenizers, slow
@custom_tokenizers
-class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = BertJapaneseTokenizer
@@ -130,7 +131,7 @@ class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester):
        assert encoded_pair == [2] + text + [3] + text_2 + [3]
-class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = BertJapaneseTokenizer
...
@@ -18,7 +18,6 @@ import os
import shutil
import sys
import tempfile
-import unittest
from io import open
@@ -43,8 +42,7 @@ else:
    unicode = str
-class CommonTestCases:
-    class CommonTokenizerTester(unittest.TestCase):
+class TokenizerTesterMixin:
    tokenizer_class = None
@@ -305,11 +303,7 @@ class CommonTestCases:
        num_added_tokens = tokenizer.num_added_tokens()
        total_length = len(sequence) + num_added_tokens
        information = tokenizer.encode_plus(
-            seq_0,
-            max_length=total_length - 2,
-            add_special_tokens=True,
-            stride=stride,
-            return_overflowing_tokens=True,
+            seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride, return_overflowing_tokens=True,
        )
        truncated_sequence = information["input_ids"]
@@ -332,8 +326,7 @@ class CommonTestCases:
        sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
        truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
-            tokenizer.encode(seq_0, add_special_tokens=False),
-            tokenizer.encode(seq_1, add_special_tokens=False)[:-2],
+            tokenizer.encode(seq_0, add_special_tokens=False), tokenizer.encode(seq_1, add_special_tokens=False)[:-2],
        )
        information = tokenizer.encode_plus(
@@ -440,9 +433,7 @@ class CommonTestCases:
        tokenizer.padding_side = "right"
        encoded_sequence = tokenizer.encode(sequence)
        sequence_length = len(encoded_sequence)
-        padded_sequence = tokenizer.encode(
-            sequence, max_length=sequence_length + padding_size, pad_to_max_length=True
-        )
+        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
        padded_sequence_length = len(padded_sequence)
        assert sequence_length + padding_size == padded_sequence_length
        assert encoded_sequence + [padding_idx] * padding_size == padded_sequence
@@ -451,9 +442,7 @@ class CommonTestCases:
        tokenizer.padding_side = "left"
        encoded_sequence = tokenizer.encode(sequence)
        sequence_length = len(encoded_sequence)
-        padded_sequence = tokenizer.encode(
-            sequence, max_length=sequence_length + padding_size, pad_to_max_length=True
-        )
+        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
        padded_sequence_length = len(padded_sequence)
        assert sequence_length + padding_size == padded_sequence_length
        assert [padding_idx] * padding_size + encoded_sequence == padded_sequence
...
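As context for the reflowed calls above, here is a hedged usage sketch of the two behaviours those tests exercise, truncation with overflow and padding to a fixed length (argument names are taken from the diff; the checkpoint name and the concrete numbers are illustrative, and the first call downloads the bert-base-uncased vocabulary):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
sequence = "This is a sequence used to check truncation and padding."

# Truncation with overflow: input_ids is cut down to max_length, and the
# tokens that were cut off are also returned when return_overflowing_tokens=True.
information = tokenizer.encode_plus(
    sequence, max_length=8, add_special_tokens=True, stride=2, return_overflowing_tokens=True,
)
assert len(information["input_ids"]) == 8

# Padding: pad_to_max_length=True pads the encoding up to max_length, on the
# side given by tokenizer.padding_side ("right" appends padding ids).
encoded_sequence = tokenizer.encode(sequence)
tokenizer.padding_side = "right"
padded_sequence = tokenizer.encode(sequence, max_length=len(encoded_sequence) + 4, pad_to_max_length=True)
assert padded_sequence[: len(encoded_sequence)] == encoded_sequence
assert len(padded_sequence) == len(encoded_sequence) + 4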
@@ -15,14 +15,15 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import json
import os
+import unittest
from io import open
from transformers.tokenization_ctrl import VOCAB_FILES_NAMES, CTRLTokenizer
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
-class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = CTRLTokenizer
...
@@ -16,14 +16,15 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import json
import os
+import unittest
from io import open
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
-class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
+class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = GPT2Tokenizer
...
@@ -16,13 +16,14 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import json
import os
+import unittest
from transformers.tokenization_openai import VOCAB_FILES_NAMES, OpenAIGPTTokenizer
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
-class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = OpenAIGPTTokenizer
...
@@ -16,15 +16,16 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import json
import os
+import unittest
from io import open
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
-class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = RobertaTokenizer
    def setUp(self):
...
@@ -15,17 +15,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
+import unittest
from transformers.tokenization_t5 import T5Tokenizer
from transformers.tokenization_xlnet import SPIECE_UNDERLINE
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
-class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):
+class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = T5Tokenizer
...
@@ -15,11 +15,12 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
+import unittest
from io import open
from transformers import is_torch_available
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
from .utils import require_torch
@@ -28,7 +29,7 @@ if is_torch_available():
@require_torch
-class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = TransfoXLTokenizer if is_torch_available() else None
...
@@ -16,14 +16,15 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import json
import os
+import unittest
from transformers.tokenization_xlm import VOCAB_FILES_NAMES, XLMTokenizer
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
-class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = XLMTokenizer
...
@@ -15,17 +15,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
+import unittest
from transformers.tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
-from .test_tokenization_common import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
-class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class XLNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = XLNetTokenizer
...