Unverified Commit 11bbb505 authored by Lysandre Debut, committed by GitHub

Adds pretrained IDs directly in the tests (#29534)

* Adds pretrained IDs directly in the tests

* Fix tests

* Fix tests

* Review!
parent 38bff8c8
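
The diff below adds a `from_pretrained_id` class attribute to each tokenizer test class, pinning the Hub checkpoint the test loads its tokenizer from directly in the test file. As a minimal sketch of how `TokenizerTesterMixin` could consume this attribute: the mixin body here is an assumption for illustration only; the attribute names (`from_pretrained_id`, `tokenizer_class`, `rust_tokenizer_class`, `from_pretrained_kwargs`) come from this diff, but the method below is hypothetical, not the actual transformers implementation.

```python
# Hypothetical sketch: one way a shared test mixin could use the
# per-class attributes set in the hunks below. Not the real
# TokenizerTesterMixin body.
class TokenizerTesterMixin:
    tokenizer_class = None         # slow (Python) tokenizer class under test
    rust_tokenizer_class = None    # fast (Rust-backed) tokenizer class, if any
    from_pretrained_id = None      # Hub checkpoint ID, now declared per test class
    from_pretrained_kwargs = None  # extra kwargs forwarded to from_pretrained

    def get_tokenizer(self, **kwargs):
        # Load the tokenizer from the checkpoint declared on the subclass,
        # rather than looking the ID up in a central mapping.
        merged = {**(self.from_pretrained_kwargs or {}), **kwargs}
        return self.tokenizer_class.from_pretrained(self.from_pretrained_id, **merged)
```

A subclass such as `TestTokenizationLED` below then only declares its checkpoint and tokenizer classes; every inherited test picks them up through the mixin.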
@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class TestTokenizationLED(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allenai/led-base-16384"
     tokenizer_class = LEDTokenizer
     rust_tokenizer_class = LEDTokenizerFast
     test_rust_tokenizer = True
...
@@ -52,6 +52,7 @@ if is_torch_available():
 @require_sentencepiece
 @require_tokenizers
 class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "hf-internal-testing/llama-tokenizer"
     tokenizer_class = LlamaTokenizer
     rust_tokenizer_class = LlamaTokenizerFast
...
@@ -30,6 +30,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest with FacebookAI/roberta-base->allenai/longformer-base-4096,Roberta->Longformer,roberta->longformer,
 class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "allenai/longformer-base-4096"
     # Ignore copy
     tokenizer_class = LongformerTokenizer
     test_slow_tokenizer = True
...
@@ -28,6 +28,7 @@ SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")
 class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "studio-ousia/luke-base"
     tokenizer_class = LukeTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"cls_token": "<s>"}
...
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class LxmertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "unc-nlp/lxmert-base-uncased"
     tokenizer_class = LxmertTokenizer
     rust_tokenizer_class = LxmertTokenizerFast
     test_rust_tokenizer = True
...
@@ -48,6 +48,7 @@ FR_CODE = 128028
 @require_sentencepiece
 class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/m2m100_418M"
     tokenizer_class = M2M100Tokenizer
     test_rust_tokenizer = False
     test_seq2seq = False
...
@@ -45,6 +45,7 @@ else:
 @require_sentencepiece
 class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "Helsinki-NLP/opus-mt-en-de"
     tokenizer_class = MarianTokenizer
     test_rust_tokenizer = False
     test_sentencepiece = True
...
@@ -41,6 +41,7 @@ logger = logging.get_logger(__name__)
 @require_tokenizers
 class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/markuplm-base"
     tokenizer_class = MarkupLMTokenizer
     rust_tokenizer_class = MarkupLMTokenizerFast
     test_rust_tokenizer = True
...
@@ -41,6 +41,7 @@ RO_CODE = 250020
 @require_sentencepiece
 @require_tokenizers
 class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mbart-large-en-ro"
     tokenizer_class = MBartTokenizer
     rust_tokenizer_class = MBartTokenizerFast
     test_rust_tokenizer = True
...
@@ -41,6 +41,7 @@ RO_CODE = 250020
 @require_sentencepiece
 @require_tokenizers
 class MBart50TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/mbart-large-50-one-to-many-mmt"
     tokenizer_class = MBart50Tokenizer
     rust_tokenizer_class = MBart50TokenizerFast
     test_rust_tokenizer = True
...
@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class MgpstrTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "alibaba-damo/mgp-str-base"
     tokenizer_class = MgpstrTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {}
...
@@ -28,6 +28,7 @@ SAMPLE_ENTITY_VOCAB = get_tests_dir("fixtures/test_entity_vocab.json")
 class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "studio-ousia/mluke-base"
     tokenizer_class = MLukeTokenizer
     test_rust_tokenizer = False
     from_pretrained_kwargs = {"cls_token": "<s>"}
...
@@ -34,6 +34,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
 @require_tokenizers
 class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "mobilebert-uncased"
     tokenizer_class = MobileBertTokenizer
     rust_tokenizer_class = MobileBertTokenizerFast
     test_rust_tokenizer = True
...
@@ -26,6 +26,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "microsoft/mpnet-base"
     tokenizer_class = MPNetTokenizer
     rust_tokenizer_class = MPNetTokenizerFast
     test_rust_tokenizer = True
...
@@ -25,6 +25,7 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_roberta_det
 @require_tokenizers
 class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "RUCAIBox/mvp"
     tokenizer_class = MvpTokenizer
     rust_tokenizer_class = MvpTokenizerFast
     test_rust_tokenizer = True
...
@@ -49,6 +49,7 @@ RO_CODE = 256145
 @require_sentencepiece
 @require_tokenizers
 class NllbTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/nllb-200-distilled-600M"
     tokenizer_class = NllbTokenizer
     rust_tokenizer_class = NllbTokenizerFast
     test_rust_tokenizer = True
...
@@ -24,6 +24,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class NougatTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "facebook/nougat-base"
     slow_tokenizer_class = None
     rust_tokenizer_class = NougatTokenizerFast
     tokenizer_class = NougatTokenizerFast
...
@@ -27,6 +27,7 @@ from ...test_tokenization_common import TokenizerTesterMixin
 @require_tokenizers
 class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "openai-community/openai-gpt"
     """Tests OpenAIGPTTokenizer that uses BERT BasicTokenizer."""
     tokenizer_class = OpenAIGPTTokenizer
...
@@ -27,6 +27,7 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
 @require_sentencepiece
 @require_tokenizers
 class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/pegasus-xsum"
     tokenizer_class = PegasusTokenizer
     rust_tokenizer_class = PegasusTokenizerFast
     test_rust_tokenizer = True
@@ -135,6 +136,7 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class BigBirdPegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "google/pegasus-xsum"
     tokenizer_class = PegasusTokenizer
     rust_tokenizer_class = PegasusTokenizerFast
     test_rust_tokenizer = True
...
@@ -36,6 +36,7 @@ else:
 class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    from_pretrained_id = "deepmind/language-perceiver"
     tokenizer_class = PerceiverTokenizer
     test_rust_tokenizer = False
...