Unverified Commit d2357a01 authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Use tiny models for ONNX tests - text modality (#20333)

* Use tiny ONNX models

* Fix broken tests

* Add tiny perceiver

* Add tiny convbert
parent 3d0c0ae4
...@@ -179,48 +179,48 @@ class OnnxConfigWithPastTestCaseV2(TestCase): ...@@ -179,48 +179,48 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
PYTORCH_EXPORT_MODELS = { PYTORCH_EXPORT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"), ("albert", "hf-internal-testing/tiny-random-AlbertModel"),
("bert", "bert-base-cased"), ("bert", "hf-internal-testing/tiny-random-BertModel"),
("big-bird", "google/bigbird-roberta-base"), ("beit", "microsoft/beit-base-patch16-224"),
("ibert", "kssteven/ibert-roberta-base"), ("big-bird", "hf-internal-testing/tiny-random-BigBirdModel"),
("camembert", "camembert-base"), ("camembert", "camembert-base"),
("clip", "openai/clip-vit-base-patch32"), ("clip", "hf-internal-testing/tiny-random-CLIPModel"),
("convbert", "YituTech/conv-bert-base"), ("convbert", "hf-internal-testing/tiny-random-ConvBertModel"),
("codegen", "Salesforce/codegen-350M-multi"), ("codegen", "hf-internal-testing/tiny-random-CodeGenModel"),
("deberta", "microsoft/deberta-base"), ("data2vec-text", "hf-internal-testing/tiny-random-Data2VecTextModel"),
("deberta-v2", "microsoft/deberta-v2-xlarge"), ("data2vec-vision", "facebook/data2vec-vision-base"),
("deberta", "hf-internal-testing/tiny-random-DebertaModel"),
("deberta-v2", "hf-internal-testing/tiny-random-DebertaV2Model"),
("deit", "facebook/deit-small-patch16-224"),
("convnext", "facebook/convnext-tiny-224"), ("convnext", "facebook/convnext-tiny-224"),
("detr", "facebook/detr-resnet-50"), ("detr", "facebook/detr-resnet-50"),
("distilbert", "distilbert-base-cased"), ("distilbert", "hf-internal-testing/tiny-random-DistilBertModel"),
("electra", "google/electra-base-generator"), ("electra", "hf-internal-testing/tiny-random-ElectraModel"),
("groupvit", "nvidia/groupvit-gcc-yfcc"),
("ibert", "kssteven/ibert-roberta-base"),
("imagegpt", "openai/imagegpt-small"), ("imagegpt", "openai/imagegpt-small"),
("resnet", "microsoft/resnet-50"), ("levit", "facebook/levit-128S"),
("roberta", "roberta-base"), ("layoutlm", "hf-internal-testing/tiny-random-LayoutLMModel"),
("roformer", "junnyu/roformer_chinese_base"), ("layoutlmv3", "microsoft/layoutlmv3-base"),
("squeezebert", "squeezebert/squeezebert-uncased"), ("longformer", "allenai/longformer-base-4096"),
("mobilebert", "google/mobilebert-uncased"), ("mobilebert", "hf-internal-testing/tiny-random-MobileBertModel"),
("mobilenet_v1", "google/mobilenet_v1_0.75_192"), ("mobilenet_v1", "google/mobilenet_v1_0.75_192"),
("mobilenet_v2", "google/mobilenet_v2_0.35_96"), ("mobilenet_v2", "google/mobilenet_v2_0.35_96"),
("mobilevit", "apple/mobilevit-small"), ("mobilevit", "apple/mobilevit-small"),
("xlm", "xlm-clm-ende-1024"),
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
("layoutlmv3", "microsoft/layoutlmv3-base"),
("groupvit", "nvidia/groupvit-gcc-yfcc"),
("levit", "facebook/levit-128S"),
("owlvit", "google/owlvit-base-patch32"), ("owlvit", "google/owlvit-base-patch32"),
("vit", "google/vit-base-patch16-224"), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
("deit", "facebook/deit-small-patch16-224"), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
("beit", "microsoft/beit-base-patch16-224"), ("resnet", "microsoft/resnet-50"),
("data2vec-text", "facebook/data2vec-text-base"), ("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
("data2vec-vision", "facebook/data2vec-vision-base"), ("roformer", "hf-internal-testing/tiny-random-RoFormerModel"),
("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
("longformer", "allenai/longformer-base-4096"),
("yolos", "hustvl/yolos-tiny"),
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"), ("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
("squeezebert", "hf-internal-testing/tiny-random-SqueezeBertModel"),
("swin", "microsoft/swin-tiny-patch4-window7-224"), ("swin", "microsoft/swin-tiny-patch4-window7-224"),
("vit", "google/vit-base-patch16-224"),
("yolos", "hustvl/yolos-tiny"),
("whisper", "openai/whisper-tiny.en"), ("whisper", "openai/whisper-tiny.en"),
("xlm", "hf-internal-testing/tiny-random-XLMModel"),
("xlm-roberta", "hf-internal-testing/tiny-random-XLMRobertaXLModel"),
} }
PYTORCH_EXPORT_ENCODER_DECODER_MODELS = { PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
...@@ -228,34 +228,31 @@ PYTORCH_EXPORT_ENCODER_DECODER_MODELS = { ...@@ -228,34 +228,31 @@ PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
} }
PYTORCH_EXPORT_WITH_PAST_MODELS = { PYTORCH_EXPORT_WITH_PAST_MODELS = {
("bloom", "bigscience/bloom-560m"), ("bloom", "hf-internal-testing/tiny-random-BloomModel"),
("gpt2", "gpt2"), ("gpt2", "hf-internal-testing/tiny-random-GPT2Model"),
("gpt-neo", "EleutherAI/gpt-neo-125M"), ("gpt-neo", "hf-internal-testing/tiny-random-GPTNeoModel"),
} }
PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = { PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("bart", "facebook/bart-base"), ("bart", "hf-internal-testing/tiny-random-BartModel"),
("mbart", "sshleifer/tiny-mbart"), ("bigbird-pegasus", "hf-internal-testing/tiny-random-BigBirdPegasusModel"),
("t5", "t5-small"), ("blenderbot-small", "facebook/blenderbot_small-90M"),
("blenderbot", "hf-internal-testing/tiny-random-BlenderbotModel"),
("longt5", "hf-internal-testing/tiny-random-LongT5Model"),
("marian", "Helsinki-NLP/opus-mt-en-de"), ("marian", "Helsinki-NLP/opus-mt-en-de"),
("mbart", "sshleifer/tiny-mbart"),
("mt5", "google/mt5-base"), ("mt5", "google/mt5-base"),
("m2m-100", "facebook/m2m100_418M"), ("m2m-100", "hf-internal-testing/tiny-random-M2M100Model"),
("blenderbot-small", "facebook/blenderbot_small-90M"), ("t5", "hf-internal-testing/tiny-random-T5Model"),
("blenderbot", "facebook/blenderbot-400M-distill"),
("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"),
("longt5", "google/long-t5-local-base"),
# Disable for now as it causes fatal error `Floating point exception (core dumped)` and the subsequential tests are
# not run.
# ("longt5", "google/long-t5-tglobal-base"),
} }
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations. # TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_DEFAULT_MODELS = { TENSORFLOW_EXPORT_DEFAULT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"), ("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"), ("bert", "hf-internal-testing/tiny-random-BertModel"),
("camembert", "camembert-base"), ("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"), ("distilbert", "hf-internal-testing/tiny-random-DistilBertModel"),
("roberta", "roberta-base"), ("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
} }
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations. # TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment