Unverified commit 11bbb505, authored by Lysandre Debut, committed by GitHub

Adds pretrained IDs directly in the tests (#29534)

* Adds pretrained IDs directly in the tests

* Fix tests

* Fix tests

* Review!
parent 38bff8c8
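For readers outside the test suite, this is the pattern the commit rolls out: each tokenizer test class pins the Hub checkpoint it exercises in a from_pretrained_id class attribute, which the shared TokenizerTesterMixin can then read instead of relying on hardcoded IDs inside helper code. Below is a minimal, self-contained sketch of that pattern; FakeTokenizer and the test method are illustrative assumptions, not the actual transformers implementation.

# A minimal, self-contained sketch (not the actual transformers code) of the
# pattern this commit introduces: each test class declares the Hub checkpoint
# it exercises via a `from_pretrained_id` class attribute, and the shared
# mixin reads that attribute instead of a hardcoded ID.
import unittest


class FakeTokenizer:
    """Stand-in for a real tokenizer class; records the ID it was loaded from."""

    def __init__(self, pretrained_id):
        self.pretrained_id = pretrained_id

    @classmethod
    def from_pretrained(cls, pretrained_id):
        # A real tokenizer would download vocab files for this checkpoint.
        return cls(pretrained_id)


class TokenizerTesterMixin:
    # Overridden by each concrete test class, as in the diff below.
    from_pretrained_id = None
    tokenizer_class = None

    def test_loads_declared_checkpoint(self):
        tokenizer = self.tokenizer_class.from_pretrained(self.from_pretrained_id)
        self.assertEqual(tokenizer.pretrained_id, self.from_pretrained_id)


class ExampleTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "openai-community/gpt2"  # per-model Hub ID
    tokenizer_class = FakeTokenizer


if __name__ == "__main__":
    unittest.main()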
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/deberta-base"
     tokenizer_class = DebertaTokenizer
     test_rust_tokenizer = True
     rust_tokenizer_class = DebertaTokenizerFast
...
@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/deberta-v2-xlarge"
     tokenizer_class = DebertaV2Tokenizer
     rust_tokenizer_class = DebertaV2TokenizerFast
     test_sentencepiece = True
...
@@ -25,6 +25,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
     tokenizer_class = DistilBertTokenizer
     rust_tokenizer_class = DistilBertTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "distilbert/distilbert-base-uncased"
     @slow
     def test_sequence_builders(self):
...
@@ -33,6 +33,7 @@ class DPRContextEncoderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRContextEncoderTokenizer
     rust_tokenizer_class = DPRContextEncoderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
 @require_tokenizers
@@ -40,6 +41,7 @@ class DPRQuestionEncoderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRQuestionEncoderTokenizer
     rust_tokenizer_class = DPRQuestionEncoderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
 @require_tokenizers
@@ -47,6 +49,7 @@ class DPRReaderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRReaderTokenizer
     rust_tokenizer_class = DPRReaderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
     @slow
     def test_decode_best_spans(self):
...
@@ -33,6 +33,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
 @require_tokenizers
 class ElectraTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/electra-small-generator"
     tokenizer_class = ElectraTokenizer
     rust_tokenizer_class = ElectraTokenizerFast
     test_rust_tokenizer = True
...
@@ -28,6 +28,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class ErnieMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "susnato/ernie-m-base_pytorch"
     tokenizer_class = ErnieMTokenizer
     test_seq2seq = False
     test_sentencepiece = True
...
@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_g2p_en
 class FastSpeech2ConformerTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "espnet/fastspeech2_conformer"
     tokenizer_class = FastSpeech2ConformerTokenizer
     test_rust_tokenizer = False
...
@@ -28,6 +28,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class FNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/fnet-base"
     tokenizer_class = FNetTokenizer
     rust_tokenizer_class = FNetTokenizerFast
     test_rust_tokenizer = True
...
@@ -30,6 +30,7 @@ FSMT_TINY2 = "stas/tiny-wmt19-en-ru"
 class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "stas/tiny-wmt19-en-de"
     tokenizer_class = FSMTTokenizer
     test_rust_tokenizer = False
...
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class FunnelTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "funnel-transformer/small"
     tokenizer_class = FunnelTokenizer
     rust_tokenizer_class = FunnelTokenizerFast
     test_rust_tokenizer = True
...
@@ -49,6 +49,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class GemmaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/gemma-7b"
     tokenizer_class = GemmaTokenizer
     rust_tokenizer_class = GemmaTokenizerFast
...
@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai-community/gpt2"
     tokenizer_class = GPT2Tokenizer
     rust_tokenizer_class = GPT2TokenizerFast
     test_rust_tokenizer = True
...
@@ -29,6 +29,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class GPTNeoXJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "abeja/gpt-neox-japanese-2.7b"
     tokenizer_class = GPTNeoXJapaneseTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False}
...
@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_with_bytefallback.model")
 @require_sentencepiece
 @require_tokenizers
 class GPTSw3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "AI-Sweden-Models/gpt-sw3-126m"
     tokenizer_class = GPTSw3Tokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True
...
@@ -29,6 +29,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class GPTSanJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Tanrei/GPTSAN-japanese"
     tokenizer_class = GPTSanJapaneseTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False}
...
@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_sacremoses
 @require_tokenizers
 class HerbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allegro/herbert-base-cased"
     tokenizer_class = HerbertTokenizer
     rust_tokenizer_class = HerbertTokenizerFast
     test_rust_tokenizer = True
...
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class LayoutLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlm-base-uncased"
     tokenizer_class = LayoutLMTokenizer
     rust_tokenizer_class = LayoutLMTokenizerFast
     test_rust_tokenizer = True
...
@@ -61,6 +61,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 @require_pandas
 class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlmv2-base-uncased"
     tokenizer_class = LayoutLMv2Tokenizer
     rust_tokenizer_class = LayoutLMv2TokenizerFast
     test_rust_tokenizer = True
...
@@ -49,6 +49,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 @require_pandas
 class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlmv3-base"
     tokenizer_class = LayoutLMv3Tokenizer
     rust_tokenizer_class = LayoutLMv3TokenizerFast
     test_rust_tokenizer = True
...
@@ -54,6 +54,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_tokenizers
 @require_pandas
 class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/xlm-roberta-base"
     tokenizer_class = LayoutXLMTokenizer
     rust_tokenizer_class = LayoutXLMTokenizerFast
     test_rust_tokenizer = True
...