Unverified Commit 5b7ffd54 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Avoid importing all models when instantiating a pipeline (#24960)

* Avoid importing all models when instantiating a pipeline

* Remove sums that don't work
parent 640e1b6c
...@@ -650,6 +650,7 @@ class _LazyAutoMapping(OrderedDict): ...@@ -650,6 +650,7 @@ class _LazyAutoMapping(OrderedDict):
self._config_mapping = config_mapping self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping self._model_mapping = model_mapping
self._model_mapping._model_mapping = self
self._extra_content = {} self._extra_content = {}
self._modules = {} self._modules = {}
......
...@@ -88,11 +88,6 @@ if is_tf_available(): ...@@ -88,11 +88,6 @@ if is_tf_available():
import tensorflow as tf import tensorflow as tf
from ..models.auto.modeling_tf_auto import ( from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel, TFAutoModel,
TFAutoModelForCausalLM, TFAutoModelForCausalLM,
TFAutoModelForImageClassification, TFAutoModelForImageClassification,
...@@ -110,13 +105,6 @@ if is_torch_available(): ...@@ -110,13 +105,6 @@ if is_torch_available():
import torch import torch
from ..models.auto.modeling_auto import ( from ..models.auto.modeling_auto import (
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
AutoModel, AutoModel,
AutoModelForAudioClassification, AutoModelForAudioClassification,
AutoModelForCausalLM, AutoModelForCausalLM,
......
...@@ -22,7 +22,7 @@ from .base import PIPELINE_INIT_ARGS, Pipeline ...@@ -22,7 +22,7 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -98,7 +98,7 @@ class AudioClassificationPipeline(Pipeline): ...@@ -98,7 +98,7 @@ class AudioClassificationPipeline(Pipeline):
if self.framework != "pt": if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.") raise ValueError(f"The {self.__class__} is only available in PyTorch.")
self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING) self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES)
def __call__( def __call__(
self, self,
......
...@@ -30,7 +30,7 @@ if TYPE_CHECKING: ...@@ -30,7 +30,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
def rescale_stride(stride, ratio): def rescale_stride(stride, ratio):
...@@ -205,7 +205,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -205,7 +205,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if self.model.config.model_type == "whisper": if self.model.config.model_type == "whisper":
self.type = "seq2seq_whisper" self.type = "seq2seq_whisper"
elif self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values(): elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
self.type = "seq2seq" self.type = "seq2seq"
elif ( elif (
feature_extractor._processor_class feature_extractor._processor_class
...@@ -220,7 +220,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -220,7 +220,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if self.framework == "tf": if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.") raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items())) mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
self.check_model_type(mapping)
def __call__( def __call__(
self, self,
......
...@@ -952,10 +952,16 @@ class Pipeline(_ScikitCompat): ...@@ -952,10 +952,16 @@ class Pipeline(_ScikitCompat):
""" """
if not isinstance(supported_models, list): # Create from a model mapping if not isinstance(supported_models, list): # Create from a model mapping
supported_models_names = [] supported_models_names = []
for config, model in supported_models.items(): for _, model_name in supported_models.items():
# Mapping can now contain tuples of models for the same configuration. # Mapping can now contain tuples of models for the same configuration.
if isinstance(model, tuple): if isinstance(model_name, tuple):
supported_models_names.extend([_model.__name__ for _model in model]) supported_models_names.extend(list(model_name))
else:
supported_models_names.append(model_name)
if hasattr(supported_models, "_model_mapping"):
for _, model in supported_models._model_mapping._extra_content.items():
if isinstance(model_name, tuple):
supported_models_names.extend([m.__name__ for m in model])
else: else:
supported_models_names.append(model.__name__) supported_models_names.append(model.__name__)
supported_models = supported_models_names supported_models = supported_models_names
......
...@@ -14,7 +14,7 @@ if is_vision_available(): ...@@ -14,7 +14,7 @@ if is_vision_available():
if is_torch_available(): if is_torch_available():
import torch import torch
from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -48,7 +48,7 @@ class DepthEstimationPipeline(Pipeline): ...@@ -48,7 +48,7 @@ class DepthEstimationPipeline(Pipeline):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING) self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs): def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
""" """
......
...@@ -37,7 +37,7 @@ if is_vision_available(): ...@@ -37,7 +37,7 @@ if is_vision_available():
if is_torch_available(): if is_torch_available():
import torch import torch
from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
TESSERACT_LOADED = False TESSERACT_LOADED = False
if is_pytesseract_available(): if is_pytesseract_available():
...@@ -142,7 +142,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): ...@@ -142,7 +142,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
if self.model.config.encoder.model_type != "donut-swin": if self.model.config.encoder.model_type != "donut-swin":
raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut") raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
else: else:
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING) self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES)
if self.model.config.__class__.__name__ == "LayoutLMConfig": if self.model.config.__class__.__name__ == "LayoutLMConfig":
self.model_type = ModelType.LayoutLM self.model_type = ModelType.LayoutLM
else: else:
......
...@@ -19,11 +19,11 @@ if is_vision_available(): ...@@ -19,11 +19,11 @@ if is_vision_available():
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
from ..tf_utils import stable_softmax from ..tf_utils import stable_softmax
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -57,9 +57,9 @@ class ImageClassificationPipeline(Pipeline): ...@@ -57,9 +57,9 @@ class ImageClassificationPipeline(Pipeline):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf" if self.framework == "tf"
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
) )
def _sanitize_parameters(self, top_k=None): def _sanitize_parameters(self, top_k=None):
......
...@@ -13,10 +13,10 @@ if is_vision_available(): ...@@ -13,10 +13,10 @@ if is_vision_available():
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import ( from ..models.auto.modeling_auto import (
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
) )
...@@ -71,14 +71,11 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -71,14 +71,11 @@ class ImageSegmentationPipeline(Pipeline):
raise ValueError(f"The {self.__class__} is only available in PyTorch.") raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type( mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
dict( mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.items() mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items() mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items() self.check_model_type(mapping)
+ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING.items()
)
)
def _sanitize_parameters(self, **kwargs): def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {} preprocess_kwargs = {}
......
...@@ -17,12 +17,12 @@ if is_vision_available(): ...@@ -17,12 +17,12 @@ if is_vision_available():
from ..image_utils import load_image from ..image_utils import load_image
if is_tf_available(): if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
if is_torch_available(): if is_torch_available():
import torch import torch
from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -55,7 +55,7 @@ class ImageToTextPipeline(Pipeline): ...@@ -55,7 +55,7 @@ class ImageToTextPipeline(Pipeline):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
) )
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None): def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
......
...@@ -14,7 +14,7 @@ from .base import PIPELINE_INIT_ARGS, ChunkPipeline ...@@ -14,7 +14,7 @@ from .base import PIPELINE_INIT_ARGS, ChunkPipeline
if is_torch_available(): if is_torch_available():
import torch import torch
from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -96,7 +96,7 @@ class MaskGenerationPipeline(ChunkPipeline): ...@@ -96,7 +96,7 @@ class MaskGenerationPipeline(ChunkPipeline):
if self.framework != "pt": if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.") raise ValueError(f"The {self.__class__} is only available in PyTorch.")
self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING) self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
def _sanitize_parameters(self, **kwargs): def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {} preprocess_kwargs = {}
......
...@@ -11,7 +11,10 @@ if is_vision_available(): ...@@ -11,7 +11,10 @@ if is_vision_available():
if is_torch_available(): if is_torch_available():
import torch import torch
from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING from ..models.auto.modeling_auto import (
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -53,9 +56,9 @@ class ObjectDetectionPipeline(Pipeline): ...@@ -53,9 +56,9 @@ class ObjectDetectionPipeline(Pipeline):
raise ValueError(f"The {self.__class__} is only available in PyTorch.") raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type( mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES.copy()
dict(MODEL_FOR_OBJECT_DETECTION_MAPPING.items() + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items()) mapping.update(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
) self.check_model_type(mapping)
def _sanitize_parameters(self, **kwargs): def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {} postprocess_kwargs = {}
......
...@@ -32,7 +32,7 @@ if TYPE_CHECKING: ...@@ -32,7 +32,7 @@ if TYPE_CHECKING:
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
Dataset = None Dataset = None
...@@ -40,7 +40,7 @@ if is_torch_available(): ...@@ -40,7 +40,7 @@ if is_torch_available():
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
def decode_spans( def decode_spans(
...@@ -270,7 +270,9 @@ class QuestionAnsweringPipeline(ChunkPipeline): ...@@ -270,7 +270,9 @@ class QuestionAnsweringPipeline(ChunkPipeline):
self._args_parser = QuestionAnsweringArgumentHandler() self._args_parser = QuestionAnsweringArgumentHandler()
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
if self.framework == "tf"
else MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
) )
@staticmethod @staticmethod
......
...@@ -17,8 +17,8 @@ if is_torch_available(): ...@@ -17,8 +17,8 @@ if is_torch_available():
import torch import torch
from ..models.auto.modeling_auto import ( from ..models.auto.modeling_auto import (
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
) )
if is_tf_available() and is_tensorflow_probability_available(): if is_tf_available() and is_tensorflow_probability_available():
...@@ -26,8 +26,8 @@ if is_tf_available() and is_tensorflow_probability_available(): ...@@ -26,8 +26,8 @@ if is_tf_available() and is_tensorflow_probability_available():
import tensorflow_probability as tfp import tensorflow_probability as tfp
from ..models.auto.modeling_tf_auto import ( from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
) )
...@@ -122,16 +122,13 @@ class TableQuestionAnsweringPipeline(Pipeline): ...@@ -122,16 +122,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._args_parser = args_parser self._args_parser = args_parser
self.check_model_type( if self.framework == "tf":
dict( mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy()
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items() mapping.update(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items() else:
) mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy()
if self.framework == "tf" mapping.update(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
else dict( self.check_model_type(mapping)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items() + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()
)
)
self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool( self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
getattr(self.model.config, "num_aggregation_labels", None) getattr(self.model.config, "num_aggregation_labels", None)
......
...@@ -9,10 +9,10 @@ from .base import PIPELINE_INIT_ARGS, Pipeline ...@@ -9,10 +9,10 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -65,9 +65,9 @@ class Text2TextGenerationPipeline(Pipeline): ...@@ -65,9 +65,9 @@ class Text2TextGenerationPipeline(Pipeline):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
if self.framework == "tf" if self.framework == "tf"
else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
) )
def _sanitize_parameters( def _sanitize_parameters(
......
...@@ -9,10 +9,10 @@ from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline ...@@ -9,10 +9,10 @@ from .base import PIPELINE_INIT_ARGS, GenericTensor, Pipeline
if is_tf_available(): if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
def sigmoid(_outputs): def sigmoid(_outputs):
...@@ -84,9 +84,9 @@ class TextClassificationPipeline(Pipeline): ...@@ -84,9 +84,9 @@ class TextClassificationPipeline(Pipeline):
super().__init__(**kwargs) super().__init__(**kwargs)
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf" if self.framework == "tf"
else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
) )
def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs): def _sanitize_parameters(self, return_all_scores=None, function_to_apply=None, top_k="", **tokenizer_kwargs):
......
import enum import enum
import warnings import warnings
from .. import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING from ..utils import add_end_docstrings, is_tf_available, is_torch_available
from ..utils import add_end_docstrings, is_tf_available
from .base import PIPELINE_INIT_ARGS, Pipeline from .base import PIPELINE_INIT_ARGS, Pipeline
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
class ReturnType(enum.Enum): class ReturnType(enum.Enum):
TENSORS = 0 TENSORS = 0
...@@ -62,7 +66,7 @@ class TextGenerationPipeline(Pipeline): ...@@ -62,7 +66,7 @@ class TextGenerationPipeline(Pipeline):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_CAUSAL_LM_MAPPING if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
) )
if "prefix" not in self._preprocess_params: if "prefix" not in self._preprocess_params:
# This is very specific. The logic is quite complex and needs to be done # This is very specific. The logic is quite complex and needs to be done
......
...@@ -17,9 +17,9 @@ from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline, Dataset ...@@ -17,9 +17,9 @@ from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline, Dataset
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
class TokenClassificationArgumentHandler(ArgumentHandler): class TokenClassificationArgumentHandler(ArgumentHandler):
...@@ -135,9 +135,9 @@ class TokenClassificationPipeline(ChunkPipeline): ...@@ -135,9 +135,9 @@ class TokenClassificationPipeline(ChunkPipeline):
def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs): def __init__(self, args_parser=TokenClassificationArgumentHandler(), *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf" if self.framework == "tf"
else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
) )
self._basic_tokenizer = BasicTokenizer(do_lower_case=False) self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
......
...@@ -13,7 +13,7 @@ if is_decord_available(): ...@@ -13,7 +13,7 @@ if is_decord_available():
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -34,7 +34,7 @@ class VideoClassificationPipeline(Pipeline): ...@@ -34,7 +34,7 @@ class VideoClassificationPipeline(Pipeline):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
requires_backends(self, "decord") requires_backends(self, "decord")
self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING) self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None): def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
preprocess_params = {} preprocess_params = {}
......
...@@ -10,7 +10,7 @@ if is_vision_available(): ...@@ -10,7 +10,7 @@ if is_vision_available():
from ..image_utils import load_image from ..image_utils import load_image
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -53,7 +53,7 @@ class VisualQuestionAnsweringPipeline(Pipeline): ...@@ -53,7 +53,7 @@ class VisualQuestionAnsweringPipeline(Pipeline):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING) self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES)
def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, **kwargs): def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, **kwargs):
preprocess_params, postprocess_params = {}, {} preprocess_params, postprocess_params = {}, {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment