Unverified commit e4d25885, authored by Patrick von Platen and committed by GitHub

[Pipelines] Add revision tag to all default pipelines (#17667)



* trigger test failure

* upload revision poc

* Update src/transformers/pipelines/base.py
Co-authored-by: Julien Chaumond <julien@huggingface.co>

* up

* add test

* correct some stuff

* Update src/transformers/pipelines/__init__.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* correct require flag
Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 4f8361af
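For context, a minimal usage sketch of the behaviour this commit introduces. The model id, revision hash, and warning text are taken from the diff below; the exact log formatting may differ slightly.

```python
from transformers import pipeline

# With this change, omitting the model resolves to a pinned (model_id, revision) pair
# from SUPPORTED_TASKS instead of a bare model id, so the weights a default pipeline
# downloads no longer drift when the Hub repository's main branch moves.
classifier = pipeline("text-classification")
# Warning logged (roughly): "No model was supplied, defaulted to
# distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (...).
# Using a pipeline without specifying a model name and revision in production is not recommended."

# An explicitly passed revision still takes precedence over the pinned default.
classifier = pipeline("text-classification", revision="main")
```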
@@ -30,7 +30,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, Aut
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer
from ..tokenization_utils_fast import PreTrainedTokenizerFast
-from ..utils import http_get, is_tf_available, is_torch_available, logging
+from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, http_get, is_tf_available, is_torch_available, logging
from .audio_classification import AudioClassificationPipeline
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
from .base import (
@@ -41,7 +41,7 @@ from .base import (
    Pipeline,
    PipelineDataFormat,
    PipelineException,
-    get_default_model,
+    get_default_model_and_revision,
    infer_framework_load_model,
)
from .conversational import Conversation, ConversationalPipeline
@@ -131,21 +131,21 @@ SUPPORTED_TASKS = {
        "impl": AudioClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
-        "default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
+        "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},
        "type": "audio",
    },
    "automatic-speech-recognition": {
        "impl": AutomaticSpeechRecognitionPipeline,
        "tf": (),
        "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
-        "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
+        "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
        "type": "multimodal",
    },
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
-        "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
+        "default": {"model": {"pt": ("distilbert-base-cased", "935ac13"), "tf": ("distilbert-base-cased", "935ac13")}},
        "type": "multimodal",
    },
    "text-classification": {
@@ -154,8 +154,8 @@ SUPPORTED_TASKS = {
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
-                "pt": "distilbert-base-uncased-finetuned-sst-2-english",
-                "tf": "distilbert-base-uncased-finetuned-sst-2-english",
+                "pt": ("distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
+                "tf": ("distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
            },
        },
        "type": "text",
@@ -166,8 +166,8 @@ SUPPORTED_TASKS = {
        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
        "default": {
            "model": {
-                "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
-                "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
+                "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
+                "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
            },
        },
        "type": "text",
@@ -177,7 +177,10 @@ SUPPORTED_TASKS = {
        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
        "default": {
-            "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
+            "model": {
+                "pt": ("distilbert-base-cased-distilled-squad", "626af31"),
+                "tf": ("distilbert-base-cased-distilled-squad", "626af31"),
+            },
        },
        "type": "text",
    },
@@ -187,9 +190,8 @@ SUPPORTED_TASKS = {
        "tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),
        "default": {
            "model": {
-                "pt": "google/tapas-base-finetuned-wtq",
-                "tokenizer": "google/tapas-base-finetuned-wtq",
-                "tf": "google/tapas-base-finetuned-wtq",
+                "pt": ("google/tapas-base-finetuned-wtq", "69ceee2"),
+                "tf": ("google/tapas-base-finetuned-wtq", "69ceee2"),
            },
        },
        "type": "text",
@@ -199,11 +201,7 @@ SUPPORTED_TASKS = {
        "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
-            "model": {
-                "pt": "dandelin/vilt-b32-finetuned-vqa",
-                "tokenizer": "dandelin/vilt-b32-finetuned-vqa",
-                "feature_extractor": "dandelin/vilt-b32-finetuned-vqa",
-            },
+            "model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "4355f59")},
        },
        "type": "multimodal",
    },
@@ -211,14 +209,14 @@ SUPPORTED_TASKS = {
        "impl": FillMaskPipeline,
        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
-        "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
+        "default": {"model": {"pt": ("distilroberta-base", "ec58a5b"), "tf": ("distilroberta-base", "ec58a5b")}},
        "type": "text",
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
-        "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
+        "default": {"model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("t5-small", "d769bba")}},
        "type": "text",
    },
    # This task is a special case as it's parametrized by SRC, TGT languages.
@@ -227,9 +225,9 @@ SUPPORTED_TASKS = {
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
-            ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
-            ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
-            ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
+            ("en", "fr"): {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
+            ("en", "de"): {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
+            ("en", "ro"): {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
        },
        "type": "text",
    },
@@ -237,14 +235,14 @@ SUPPORTED_TASKS = {
        "impl": Text2TextGenerationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
-        "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
+        "default": {"model": {"pt": ("t5-base", "686f1db"), "tf": ("t5-base", "686f1db")}},
        "type": "text",
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
-        "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
+        "default": {"model": {"pt": ("gpt2", "6c0e608"), "tf": ("gpt2", "6c0e608")}},
        "type": "text",
    },
    "zero-shot-classification": {
@@ -252,9 +250,8 @@ SUPPORTED_TASKS = {
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
-            "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
-            "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
-            "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
+            "model": {"pt": ("facebook/bart-large-mnli", "c626438"), "tf": ("roberta-large-mnli", "130fb28")},
+            "config": {"pt": ("facebook/bart-large-mnli", "c626438"), "tf": ("roberta-large-mnli", "130fb28")},
        },
        "type": "text",
    },
@@ -262,35 +259,42 @@ SUPPORTED_TASKS = {
        "impl": ZeroShotImageClassificationPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
-        "default": {"model": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"}},
+        "default": {
+            "model": {
+                "pt": ("openai/clip-vit-base-patch32", "f4881ba"),
+                "tf": ("openai/clip-vit-base-patch32", "f4881ba"),
+            }
+        },
        "type": "multimodal",
    },
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
-        "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
+        "default": {
+            "model": {"pt": ("microsoft/DialoGPT-medium", "8bada3b"), "tf": ("microsoft/DialoGPT-medium", "8bada3b")}
+        },
        "type": "text",
    },
    "image-classification": {
        "impl": ImageClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
-        "default": {"model": {"pt": "google/vit-base-patch16-224"}},
+        "default": {"model": {"pt": ("google/vit-base-patch16-224", "5dca96d")}},
        "type": "image",
    },
    "image-segmentation": {
        "impl": ImageSegmentationPipeline,
        "tf": (),
        "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
-        "default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
+        "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
        "type": "image",
    },
    "object-detection": {
        "impl": ObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
-        "default": {"model": {"pt": "facebook/detr-resnet-50"}},
+        "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
        "type": "image",
    },
}
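As a quick illustration of the new entry shape: an excerpt copied from the table above, read with plain dict indexing (this snippet is not from the library itself).

```python
# Each framework key under "default" -> "model" now maps to a (model_id, revision)
# tuple rather than a bare model id string.
SUPPORTED_TASKS_EXCERPT = {
    "fill-mask": {
        "default": {
            "model": {
                "pt": ("distilroberta-base", "ec58a5b"),
                "tf": ("distilroberta-base", "ec58a5b"),
            }
        },
    },
}

model_id, revision = SUPPORTED_TASKS_EXCERPT["fill-mask"]["default"]["model"]["pt"]
print(model_id, revision)  # distilroberta-base ec58a5b
```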
@@ -545,8 +549,13 @@ def pipeline(
    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        # At that point framework might still be undetermined
-        model = get_default_model(targeted_task, framework, task_options)
-        logger.warning(f"No model was supplied, defaulted to {model} (https://huggingface.co/{model})")
+        model, default_revision = get_default_model_and_revision(targeted_task, framework, task_options)
+        revision = revision if revision is not None else default_revision
+        logger.warning(
+            f"No model was supplied, defaulted to {model} and revision"
+            f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
+            "Using a pipeline without specifying a model name and revision in production is not recommended."
+        )

    # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
    model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
...
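The precedence rule added in the hunk above boils down to a one-line fallback. A standalone sketch, where `resolve_revision` is a hypothetical helper and not part of the library:

```python
def resolve_revision(user_revision, default_revision):
    # A caller-supplied revision always wins; the pinned default only applies
    # when the pipeline() call did not specify one.
    return user_revision if user_revision is not None else default_revision


assert resolve_revision(None, "af0f99b") == "af0f99b"
assert resolve_revision("main", "af0f99b") == "main"
```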
@@ -341,7 +341,9 @@ def get_framework(model, revision: Optional[str] = None):
    return framework


-def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
+def get_default_model_and_revision(
+    targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]
+) -> Union[str, Tuple[str, str]]:
    """
    Select a default model to use for a given task. Defaults to pytorch if ambiguous.
...
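The body of the renamed helper in src/transformers/pipelines/base.py is truncated above. A rough sketch of what it plausibly does, inferred only from the new signature and the SUPPORTED_TASKS entries shown earlier (the real implementation may differ):

```python
from typing import Any, Dict, Optional, Tuple, Union


def get_default_model_and_revision_sketch(
    targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]
) -> Union[str, Tuple[str, str]]:
    defaults = targeted_task["default"]
    if task_options:
        # e.g. translation defaults are keyed by a (src, tgt) language pair
        defaults = defaults[task_options]
    if framework is None:
        framework = "pt"  # "Defaults to pytorch if ambiguous", per the docstring above
    # After this commit the stored value is usually a (model_id, revision) tuple
    return defaults["model"][framework]
```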
@@ -22,6 +22,8 @@ from abc import abstractmethod
from functools import lru_cache
from unittest import skipIf

+import numpy as np
+
from transformers import (
    FEATURE_EXTRACTOR_MAPPING,
    TOKENIZER_MAPPING,
@@ -35,7 +37,15 @@ from transformers import (
)
from transformers.pipelines import get_task
from transformers.pipelines.base import _pad
-from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch
+from transformers.testing_utils import (
+    is_pipeline_test,
+    nested_simplify,
+    require_scatter,
+    require_tensorflow_probability,
+    require_tf,
+    require_torch,
+    slow,
+)


logger = logging.getLogger(__name__)
@@ -461,8 +471,8 @@ class PipelinePadTest(unittest.TestCase):
@is_pipeline_test
-@require_torch
class PipelineUtilsTest(unittest.TestCase):
+    @require_torch
    def test_pipeline_dataset(self):
        from transformers.pipelines.pt_utils import PipelineDataset
@@ -476,6 +486,7 @@ class PipelineUtilsTest(unittest.TestCase):
        outputs = [dataset[i] for i in range(4)]
        self.assertEqual(outputs, [2, 3, 4, 5])

+    @require_torch
    def test_pipeline_iterator(self):
        from transformers.pipelines.pt_utils import PipelineIterator
@@ -490,6 +501,7 @@ class PipelineUtilsTest(unittest.TestCase):
        outputs = [item for item in dataset]
        self.assertEqual(outputs, [2, 3, 4, 5])

+    @require_torch
    def test_pipeline_iterator_no_len(self):
        from transformers.pipelines.pt_utils import PipelineIterator
@@ -507,6 +519,7 @@ class PipelineUtilsTest(unittest.TestCase):
        outputs = [item for item in dataset]
        self.assertEqual(outputs, [2, 3, 4, 5])

+    @require_torch
    def test_pipeline_batch_unbatch_iterator(self):
        from transformers.pipelines.pt_utils import PipelineIterator
@@ -520,6 +533,7 @@ class PipelineUtilsTest(unittest.TestCase):
        outputs = [item for item in dataset]
        self.assertEqual(outputs, [{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}])

+    @require_torch
    def test_pipeline_batch_unbatch_iterator_tensors(self):
        import torch
@@ -537,6 +551,7 @@ class PipelineUtilsTest(unittest.TestCase):
            nested_simplify(outputs), [{"id": [[12, 22]]}, {"id": [[2, 3]]}, {"id": [[2, 4]]}, {"id": [[5]]}]
        )

+    @require_torch
    def test_pipeline_chunk_iterator(self):
        from transformers.pipelines.pt_utils import PipelineChunkIterator
@@ -552,6 +567,7 @@ class PipelineUtilsTest(unittest.TestCase):
        self.assertEqual(outputs, [0, 1, 0, 1, 2])

+    @require_torch
    def test_pipeline_pack_iterator(self):
        from transformers.pipelines.pt_utils import PipelinePackIterator
@@ -584,6 +600,7 @@ class PipelineUtilsTest(unittest.TestCase):
            ],
        )

+    @require_torch
    def test_pipeline_pack_unbatch_iterator(self):
        from transformers.pipelines.pt_utils import PipelinePackIterator
@@ -607,3 +624,125 @@ class PipelineUtilsTest(unittest.TestCase):
        outputs = [item for item in dataset]
        self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])

@slow
@require_torch
def test_load_default_pipelines_pt(self):
import torch
from transformers.pipelines import SUPPORTED_TASKS
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
for task in SUPPORTED_TASKS.keys():
if task == "table-question-answering":
# test table in separate test due to more dependencies
continue
self.check_default_pipeline(task, "pt", set_seed_fn, self.check_models_equal_pt)
@slow
@require_tf
def test_load_default_pipelines_tf(self):
import tensorflow as tf
from transformers.pipelines import SUPPORTED_TASKS
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
for task in SUPPORTED_TASKS.keys():
if task == "table-question-answering":
# test table in separate test due to more dependencies
continue
self.check_default_pipeline(task, "tf", set_seed_fn, self.check_models_equal_tf)
@slow
@require_torch
@require_scatter
def test_load_default_pipelines_pt_table_qa(self):
import torch
set_seed_fn = lambda: torch.manual_seed(0) # noqa: E731
self.check_default_pipeline("table-question-answering", "pt", set_seed_fn, self.check_models_equal_pt)
@slow
@require_tf
@require_tensorflow_probability
def test_load_default_pipelines_tf_table_qa(self):
import tensorflow as tf
set_seed_fn = lambda: tf.random.set_seed(0) # noqa: E731
self.check_default_pipeline("table-question-answering", "tf", set_seed_fn, self.check_models_equal_tf)
def check_default_pipeline(self, task, framework, set_seed_fn, check_models_equal_fn):
from transformers.pipelines import SUPPORTED_TASKS, pipeline
task_dict = SUPPORTED_TASKS[task]
# test to compare pipeline to manually loading the respective model
model = None
relevant_auto_classes = task_dict[framework]
if len(relevant_auto_classes) == 0:
# task has no default
logger.debug(f"{task} in {framework} has no default")
return
# by default use first class
auto_model_cls = relevant_auto_classes[0]
# retrieve correct model ids
if task == "translation":
# special case for translation pipeline which has multiple languages
model_ids = []
revisions = []
tasks = []
for translation_pair in task_dict["default"].keys():
model_id, revision = task_dict["default"][translation_pair]["model"][framework]
model_ids.append(model_id)
revisions.append(revision)
tasks.append(task + f"_{'_to_'.join(translation_pair)}")
else:
# normal case - non-translation pipeline
model_id, revision = task_dict["default"]["model"][framework]
model_ids = [model_id]
revisions = [revision]
tasks = [task]
# check for equality
for model_id, revision, task in zip(model_ids, revisions, tasks):
# load default model
try:
set_seed_fn()
model = auto_model_cls.from_pretrained(model_id, revision=revision)
except ValueError:
# first auto class is possibly not compatible with the model, go to the next model class
auto_model_cls = relevant_auto_classes[1]
set_seed_fn()
model = auto_model_cls.from_pretrained(model_id, revision=revision)
# load default pipeline
set_seed_fn()
default_pipeline = pipeline(task, framework=framework)
# compare pipeline model with default model
models_are_equal = check_models_equal_fn(default_pipeline.model, model)
self.assertTrue(models_are_equal, f"{task} model doesn't match pipeline.")
logger.debug(f"{task} in {framework} succeeded with {model_id}.")
def check_models_equal_pt(self, model1, model2):
models_are_equal = True
for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
if model1_p.data.ne(model2_p.data).sum() > 0:
models_are_equal = False
return models_are_equal
def check_models_equal_tf(self, model1, model2):
models_are_equal = True
for model1_p, model2_p in zip(model1.weights, model2.weights):
if np.abs(model1_p.numpy() - model2_p.numpy()).sum() > 1e-5:
models_are_equal = False
return models_are_equal
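A small sketch of what these new slow tests verify for a single task: the default pipeline's weights match the explicitly pinned checkpoint. The model id and revision come from SUPPORTED_TASKS above; the torch.equal-based comparison is an equivalent formulation of check_models_equal_pt, not the test code itself.

```python
import torch
from transformers import AutoModelForSequenceClassification, pipeline

# Load the pinned default checkpoint directly...
reference = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english", revision="af0f99b"
)
# ...and the default pipeline for the same task.
default_pipe = pipeline("text-classification", framework="pt")

# Every corresponding parameter tensor should match exactly.
weights_match = all(
    torch.equal(p1.data, p2.data)
    for p1, p2 in zip(default_pipe.model.parameters(), reference.parameters())
)
print(weights_match)
```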