Unverified Commit 11bbb505 authored by Lysandre Debut, committed by GitHub

Adds pretrained IDs directly in the tests (#29534)

* Adds pretrained IDs directly in the tests

* Fix tests

* Fix tests

* Review!
parent 38bff8c8
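
The hunks below only add a from_pretrained_id class attribute to each tokenizer test class. As a rough illustration of how such an attribute can be consumed by a shared test mixin, here is a minimal sketch; it is not the actual TokenizerTesterMixin from test_tokenization_common, and the mixin name, the helper get_pretrained_tokenizer, and the use of AutoTokenizer are assumptions made only for this example.

import unittest

from transformers import AutoTokenizer


class SimpleTokenizerTesterMixin:
    # Each concrete test class overrides these with its Hub checkpoint ID and loading kwargs.
    from_pretrained_id = None
    from_pretrained_kwargs = None

    def get_pretrained_tokenizer(self):
        # Load the tokenizer named by the subclass; skip the test if no ID was configured.
        if self.from_pretrained_id is None:
            self.skipTest("no pretrained ID configured for this test class")
        return AutoTokenizer.from_pretrained(self.from_pretrained_id, **(self.from_pretrained_kwargs or {}))


class ExampleTokenizationTest(SimpleTokenizerTesterMixin, unittest.TestCase):
    # Reuses one of the IDs added in the hunks below.
    from_pretrained_id = "openai-community/gpt2"

    def test_loads_pretrained_tokenizer(self):
        tokenizer = self.get_pretrained_tokenizer()
        self.assertGreater(tokenizer.vocab_size, 0)

Keeping the checkpoint ID on the test class itself, rather than in a shared constant, lets each model's tests declare which Hub repo they exercise, which is what the diff below does for every tokenizer test file.
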
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/deberta-base"
     tokenizer_class = DebertaTokenizer
     test_rust_tokenizer = True
     rust_tokenizer_class = DebertaTokenizerFast
@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class DebertaV2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/deberta-v2-xlarge"
     tokenizer_class = DebertaV2Tokenizer
     rust_tokenizer_class = DebertaV2TokenizerFast
     test_sentencepiece = True
@@ -25,6 +25,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
     tokenizer_class = DistilBertTokenizer
     rust_tokenizer_class = DistilBertTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "distilbert/distilbert-base-uncased"
     @slow
     def test_sequence_builders(self):
@@ -33,6 +33,7 @@ class DPRContextEncoderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRContextEncoderTokenizer
     rust_tokenizer_class = DPRContextEncoderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
 @require_tokenizers
@@ -40,6 +41,7 @@ class DPRQuestionEncoderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRQuestionEncoderTokenizer
     rust_tokenizer_class = DPRQuestionEncoderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
 @require_tokenizers
@@ -47,6 +49,7 @@ class DPRReaderTokenizationTest(BertTokenizationTest):
     tokenizer_class = DPRReaderTokenizer
     rust_tokenizer_class = DPRReaderTokenizerFast
     test_rust_tokenizer = True
+    from_pretrained_id = "facebook/dpr-ctx_encoder-single-nq-base"
     @slow
     def test_decode_best_spans(self):
@@ -33,6 +33,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
 @require_tokenizers
 class ElectraTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/electra-small-generator"
     tokenizer_class = ElectraTokenizer
     rust_tokenizer_class = ElectraTokenizerFast
     test_rust_tokenizer = True
@@ -28,6 +28,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class ErnieMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "susnato/ernie-m-base_pytorch"
     tokenizer_class = ErnieMTokenizer
     test_seq2seq = False
     test_sentencepiece = True
@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_g2p_en
 class FastSpeech2ConformerTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "espnet/fastspeech2_conformer"
     tokenizer_class = FastSpeech2ConformerTokenizer
     test_rust_tokenizer = False
@@ -28,6 +28,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/spiece.model")
 @require_sentencepiece
 @require_tokenizers
 class FNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/fnet-base"
     tokenizer_class = FNetTokenizer
     rust_tokenizer_class = FNetTokenizerFast
     test_rust_tokenizer = True
@@ -30,6 +30,7 @@ FSMT_TINY2 = "stas/tiny-wmt19-en-ru"
 class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "stas/tiny-wmt19-en-de"
     tokenizer_class = FSMTTokenizer
     test_rust_tokenizer = False
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class FunnelTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "funnel-transformer/small"
     tokenizer_class = FunnelTokenizer
     rust_tokenizer_class = FunnelTokenizerFast
     test_rust_tokenizer = True
@@ -49,6 +49,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class GemmaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/gemma-7b"
     tokenizer_class = GemmaTokenizer
     rust_tokenizer_class = GemmaTokenizerFast
@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai-community/gpt2"
     tokenizer_class = GPT2Tokenizer
     rust_tokenizer_class = GPT2TokenizerFast
     test_rust_tokenizer = True
@@ -29,6 +29,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class GPTNeoXJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "abeja/gpt-neox-japanese-2.7b"
     tokenizer_class = GPTNeoXJapaneseTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False}
@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_with_bytefallback.mode
 @require_sentencepiece
 @require_tokenizers
 class GPTSw3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "AI-Sweden-Models/gpt-sw3-126m"
     tokenizer_class = GPTSw3Tokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True
@@ -29,6 +29,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class GPTSanJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Tanrei/GPTSAN-japanese"
     tokenizer_class = GPTSanJapaneseTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False}
@@ -28,6 +28,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_sacremoses
 @require_tokenizers
 class HerbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allegro/herbert-base-cased"
     tokenizer_class = HerbertTokenizer
     rust_tokenizer_class = HerbertTokenizerFast
     test_rust_tokenizer = True
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class LayoutLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlm-base-uncased"
     tokenizer_class = LayoutLMTokenizer
     rust_tokenizer_class = LayoutLMTokenizerFast
     test_rust_tokenizer = True
@@ -61,6 +61,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 @require_pandas
 class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlmv2-base-uncased"
     tokenizer_class = LayoutLMv2Tokenizer
     rust_tokenizer_class = LayoutLMv2TokenizerFast
     test_rust_tokenizer = True
@@ -49,6 +49,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 @require_pandas
 class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/layoutlmv3-base"
     tokenizer_class = LayoutLMv3Tokenizer
     rust_tokenizer_class = LayoutLMv3TokenizerFast
     test_rust_tokenizer = True
@@ -54,6 +54,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 @require_tokenizers
 @require_pandas
 class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "FacebookAI/xlm-roberta-base"
     tokenizer_class = LayoutXLMTokenizer
     rust_tokenizer_class = LayoutXLMTokenizerFast
     test_rust_tokenizer = True