Unverified Commit 11bbb505 authored by Lysandre Debut, committed by GitHub

Adds pretrained IDs directly in the tests (#29534)

* Adds pretrained IDs directly in the tests

* Fix tests

* Fix tests

* Review!
parent 38bff8c8
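The pattern this commit applies across the tokenizer test files: each test class now declares the Hub checkpoint it exercises via a `from_pretrained_id` class attribute, which the shared `TokenizerTesterMixin` can then resolve when building the tokenizer under test. A minimal sketch of that plumbing, assuming a simplified stand-in mixin (the real `TokenizerTesterMixin` in `tests/test_tokenization_common.py` is far larger; this is illustrative, not its actual code):

```python
import unittest

from transformers import LEDTokenizer


class SimplifiedTokenizerTesterMixin:
    # Illustrative stand-in for TokenizerTesterMixin; only the
    # from_pretrained_id plumbing added by this PR is sketched here.
    from_pretrained_id = None          # Hub checkpoint, set per test class
    from_pretrained_kwargs = None      # optional kwargs forwarded on load
    tokenizer_class = None

    def setUp(self):
        super().setUp()
        kwargs = self.from_pretrained_kwargs or {}
        # Resolve the checkpoint declared on the concrete subclass.
        self.tokenizer = self.tokenizer_class.from_pretrained(
            self.from_pretrained_id, **kwargs
        )


class ExampleLEDTokenizationTest(SimplifiedTokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "allenai/led-base-16384"
    tokenizer_class = LEDTokenizer

    def test_encode_decode_roundtrip(self):
        text = "Hello world"
        ids = self.tokenizer(text).input_ids
        self.assertEqual(
            self.tokenizer.decode(ids, skip_special_tokens=True), text
        )
```

Declaring the ID on the class keeps each test file self-documenting about which checkpoint it depends on, rather than deriving it indirectly.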
@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class TestTokenizationLED(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allenai/led-base-16384"
     tokenizer_class = LEDTokenizer
     rust_tokenizer_class = LEDTokenizerFast
     test_rust_tokenizer = True
@@ -52,6 +52,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "hf-internal-testing/llama-tokenizer"
     tokenizer_class = LlamaTokenizer
     rust_tokenizer_class = LlamaTokenizerFast
@@ -30,6 +30,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest with FacebookAI/roberta-base->allenai/longformer-base-4096,Roberta->Longformer,roberta->longformer,
 class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allenai/longformer-base-4096"
     # Ignore copy
     tokenizer_class = LongformerTokenizer
     test_slow_tokenizer = True
@@ -28,6 +28,7 @@ SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")
 class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "studio-ousia/luke-base"
     tokenizer_class = LukeTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"cls_token": "<s>"}
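Some test classes, like `LukeTokenizerTest` above, pair the new `from_pretrained_id` with an existing `from_pretrained_kwargs` attribute. A short sketch of how the two would combine on load; the literal values here simply mirror the class attributes in the hunk above:

```python
from transformers import LukeTokenizer

# The kwargs are forwarded alongside the pretrained ID, so the test
# tokenizer is created with the requested special token.
tokenizer = LukeTokenizer.from_pretrained(
    "studio-ousia/luke-base", cls_token="<s>"
)
assert tokenizer.cls_token == "<s>"
```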
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class LxmertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "unc-nlp/lxmert-base-uncased"
     tokenizer_class = LxmertTokenizer
     rust_tokenizer_class = LxmertTokenizerFast
     test_rust_tokenizer = True
@@ -48,6 +48,7 @@ FR_CODE = 128028
 @require_sentencepiece
 class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/m2m100_418M"
     tokenizer_class = M2M100Tokenizer
     test_rust_tokenizer = False
     test_seq2seq = False
@@ -45,6 +45,7 @@ else:
 @require_sentencepiece
 class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Helsinki-NLP/opus-mt-en-de"
     tokenizer_class = MarianTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True
@@ -41,6 +41,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/markuplm-base"
     tokenizer_class = MarkupLMTokenizer
     rust_tokenizer_class = MarkupLMTokenizerFast
     test_rust_tokenizer = True
@@ -41,6 +41,7 @@ RO_CODE = 250020
 @require_sentencepiece
 @require_tokenizers
 class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mbart-large-en-ro"
     tokenizer_class = MBartTokenizer
     rust_tokenizer_class = MBartTokenizerFast
     test_rust_tokenizer = True
@@ -41,6 +41,7 @@ RO_CODE = 250020
 @require_sentencepiece
 @require_tokenizers
 class MBart50TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mbart-large-50-one-to-many-mmt"
     tokenizer_class = MBart50Tokenizer
     rust_tokenizer_class = MBart50TokenizerFast
     test_rust_tokenizer = True
@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class MgpstrTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "alibaba-damo/mgp-str-base"
     tokenizer_class = MgpstrTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {}
@@ -28,6 +28,7 @@ SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")
 class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "studio-ousia/mluke-base"
     tokenizer_class = MLukeTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"cls_token": "<s>"}
@@ -34,6 +34,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
 @require_tokenizers
 class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "mobilebert-uncased"
     tokenizer_class = MobileBertTokenizer
     rust_tokenizer_class = MobileBertTokenizerFast
     test_rust_tokenizer = True
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/mpnet-base"
     tokenizer_class = MPNetTokenizer
     rust_tokenizer_class = MPNetTokenizerFast
     test_rust_tokenizer = True
@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_roberta_det
 @require_tokenizers
 class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "RUCAIBox/mvp"
     tokenizer_class = MvpTokenizer
     rust_tokenizer_class = MvpTokenizerFast
     test_rust_tokenizer = True
@@ -49,6 +49,7 @@ RO_CODE = 256145
 @require_sentencepiece
 @require_tokenizers
 class NllbTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/nllb-200-distilled-600M"
     tokenizer_class = NllbTokenizer
     rust_tokenizer_class = NllbTokenizerFast
     test_rust_tokenizer = True
@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class NougatTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/nougat-base"
     slow_tokenizer_class = None
     rust_tokenizer_class = NougatTokenizerFast
     tokenizer_class = NougatTokenizerFast
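Note that `NougatTokenizationTest` above keeps `slow_tokenizer_class = None`: as the hunk shows, both `tokenizer_class` and `rust_tokenizer_class` point at `NougatTokenizerFast`, i.e. Nougat ships only a Rust-backed (fast) tokenizer. A minimal check of that setup:

```python
from transformers import NougatTokenizerFast

tokenizer = NougatTokenizerFast.from_pretrained("facebook/nougat-base")
assert tokenizer.is_fast  # no slow (pure-Python) Nougat tokenizer exists
```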
@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai-community/openai-gpt"
     """Tests OpenAIGPTTokenizer that uses BERT BasicTokenizer."""
     tokenizer_class = OpenAIGPTTokenizer
@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
 @require_sentencepiece
 @require_tokenizers
 class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/pegasus-xsum"
     tokenizer_class = PegasusTokenizer
     rust_tokenizer_class = PegasusTokenizerFast
     test_rust_tokenizer = True
@@ -135,6 +136,7 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class BigBirdPegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/pegasus-xsum"
     tokenizer_class = PegasusTokenizer
     rust_tokenizer_class = PegasusTokenizerFast
     test_rust_tokenizer = True
@@ -36,6 +36,7 @@ else:
 class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "deepmind/language-perceiver"
     tokenizer_class = PerceiverTokenizer
     test_rust_tokenizer = False