Unverified Commit 5b7ffd54 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Avoid importing all models when instantiating a pipeline (#24960)

* Avoid importing all models when instantiating a pipeline

* Remove sums that don't work
parent 640e1b6c
......@@ -650,6 +650,7 @@ class _LazyAutoMapping(OrderedDict):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._model_mapping._model_mapping = self
self._extra_content = {}
self._modules = {}
......
......@@ -88,11 +88,6 @@ if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForImageClassification,
......@@ -110,13 +105,6 @@ if is_torch_available():
import torch
from ..models.auto.modeling_auto import (
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
......
......@@ -22,7 +22,7 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
logger = logging.get_logger(__name__)
......@@ -98,7 +98,7 @@ class AudioClassificationPipeline(Pipeline):
if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING)
self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES)
def __call__(
self,
......
......@@ -30,7 +30,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
def rescale_stride(stride, ratio):
......@@ -205,7 +205,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if self.model.config.model_type == "whisper":
self.type = "seq2seq_whisper"
elif self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
self.type = "seq2seq"
elif (
feature_extractor._processor_class
......@@ -220,7 +220,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
self.check_model_type(mapping)
def __call__(
self,
......
......@@ -952,12 +952,18 @@ class Pipeline(_ScikitCompat):
"""
if not isinstance(supported_models, list): # Create from a model mapping
supported_models_names = []
for config, model in supported_models.items():
for _, model_name in supported_models.items():
# Mapping can now contain tuples of models for the same configuration.
if isinstance(model, tuple):
supported_models_names.extend([_model.__name__ for _model in model])
if isinstance(model_name, tuple):
supported_models_names.extend(list(model_name))
else:
supported_models_names.append(model.__name__)
supported_models_names.append(model_name)
if hasattr(supported_models, "_model_mapping"):
for _, model in supported_models._model_mapping._extra_content.items():
if isinstance(model_name, tuple):
supported_models_names.extend([m.__name__ for m in model])
else:
supported_models_names.append(model.__name__)
supported_models = supported_models_names
if self.model.__class__.__name__ not in supported_models:
logger.error(
......
......@@ -14,7 +14,7 @@ if is_vision_available():
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
logger = logging.get_logger(__name__)
......@@ -48,7 +48,7 @@ class DepthEstimationPipeline(Pipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING)
self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
"""
......
......@@ -37,7 +37,7 @@ if is_vision_available():
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
TESSERACT_LOADED = False
if is_pytesseract_available():
......@@ -142,7 +142,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
if self.model.config.encoder.model_type != "donut-swin":
raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
else:
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING)
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES)
if self.model.config.__class__.__name__ == "LayoutLMConfig":
self.model_type = ModelType.LayoutLM
else:
......
......@@ -19,11 +19,11 @@ if is_vision_available():
if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
from ..tf_utils import stable_softmax
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
logger = logging.get_logger(__name__)
......@@ -57,9 +57,9 @@ class ImageClassificationPipeline(Pipeline):
super().__init__(*args, **kwargs)
requires_backends(self, "vision")
self.check_model_type(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
def _sanitize_parameters(self, top_k=None):
......
......@@ -13,10 +13,10 @@ if is_vision_available():
if is_torch_available():
from ..models.auto.modeling_auto import (
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
)
......@@ -71,14 +71,11 @@ class ImageSegmentationPipeline(Pipeline):
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision")
self.check_model_type(
dict(
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items()
+ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING.items()
)
)
mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
self.check_model_type(mapping)
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
......
......@@ -17,12 +17,12 @@ if is_vision_available():
from ..image_utils import load_image
if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
logger = logging.get_logger(__name__)
......@@ -55,7 +55,7 @@ class ImageToTextPipeline(Pipeline):
super().__init__(*args, **kwargs)
requires_backends(self, "vision")
self.check_model_type(
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
)
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
......
......@@ -14,7 +14,7 @@ from .base import PIPELINE_INIT_ARGS, ChunkPipeline
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
logger = logging.get_logger(__name__)
......@@ -96,7 +96,7 @@ class MaskGenerationPipeline(ChunkPipeline):
if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING)
self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
......
......@@ -11,7 +11,10 @@ if is_vision_available():
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import (
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
logger = logging.get_logger(__name__)
......@@ -53,9 +56,9 @@ class ObjectDetectionPipeline(Pipeline):
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision")
self.check_model_type(
dict(MODEL_FOR_OBJECT_DETECTION_MAPPING.items() + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items())
)
mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
self.check_model_type(mapping)
def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {}
......
......@@ -32,7 +32,7 @@ if TYPE_CHECKING:
if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
Dataset = None
......@@ -40,7 +40,7 @@ if is_torch_available():
import torch
from torch.utils.data import Dataset
from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
def decode_spans(
......@@ -270,7 +270,9 @@ class QuestionAnsweringPipeline(ChunkPipeline):
self._args_parser = QuestionAnsweringArgumentHandler()
self.check_model_type(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
@staticmethod
......
......@@ -17,8 +17,8 @@ if is_torch_available():
import torch
from ..models.auto.modeling_auto import (
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
if is_tf_available() and is_tensorflow_probability_available():
......@@ -26,8 +26,8 @@ if is_tf_available() and is_tensorflow_probability_available():
import tensorflow_probability as tfp
from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
......@@ -122,16 +122,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
super().__init__(*args, **kwargs)
self._args_parser = args_parser
self.check_model_type(
dict(
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items()
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()
)
if self.framework == "tf"
else dict(
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items() + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()
)
)
if self.framework == "tf":
mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy()
mapping.update(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
else:
mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
self.check_model_type(mapping)
self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
getattr(self.model.config, "num_aggregation_labels", None)
......
......@@ -9,10 +9,10 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
logger = logging.get_logger(__name__)
......@@ -65,9 +65,9 @@ class Text2TextGenerationPipeline(Pipeline):
super().__init__(*args, **kwargs)
self.check_model_type(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
def _sanitize_parameters(
......
......@@ -9,10 +9,10 @@ from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline
if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
def sigmoid(_outputs):
......@@ -84,9 +84,9 @@ class TextClassificationPipeline(Pipeline):
super().__init__(**kwargs)
self.check_model_type(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs):
......
import enum
import warnings
from .. import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
from ..utils import add_end_docstrings, is_tf_available
from ..utils import add_end_docstrings, is_tf_available, is_torch_available
from .base import PIPELINE_INIT_ARGS, Pipeline
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
class ReturnType(enum.Enum):
TENSORS = 0
......@@ -62,7 +66,7 @@ class TextGenerationPipeline(Pipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.check_model_type(
TF_MODEL_FOR_CAUSAL_LM_MAPPING if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
if "prefix" not in self._preprocess_params:
# This is very specific. The logic is quite complex and needs to be done
......
......@@ -17,9 +17,9 @@ from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline, Dataset
if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
class TokenClassificationArgumentHandler(ArgumentHandler):
......@@ -135,9 +135,9 @@ class TokenClassificationPipeline(ChunkPipeline):
def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
super().__init__(*args, **kwargs)
self.check_model_type(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
......
......@@ -13,7 +13,7 @@ if is_decord_available():
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
logger = logging.get_logger(__name__)
......@@ -34,7 +34,7 @@ class VideoClassificationPipeline(Pipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
requires_backends(self, "decord")
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING)
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
preprocess_params = {}
......
......@@ -10,7 +10,7 @@ if is_vision_available():
from ..image_utils import load_image
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
from ..models.auto.modeling_auto import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
logger = logging.get_logger(__name__)
......@@ -53,7 +53,7 @@ class VisualQuestionAnsweringPipeline(Pipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING)
self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES)
def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, **kwargs):
preprocess_params, postprocess_params = {}, {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment