Unverified Commit 5b7ffd54 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Avoid importing all models when instantiating a pipeline (#24960)

* Avoid importing all models when instantiating a pipeline

* Remove sums that don't work
parent 640e1b6c
...@@ -18,10 +18,10 @@ if is_vision_available(): ...@@ -18,10 +18,10 @@ if is_vision_available():
from ..image_utils import load_image from ..image_utils import load_image
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
if is_tf_available(): if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
from ..tf_utils import stable_softmax from ..tf_utils import stable_softmax
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -66,9 +66,9 @@ class ZeroShotImageClassificationPipeline(Pipeline): ...@@ -66,9 +66,9 @@ class ZeroShotImageClassificationPipeline(Pipeline):
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
if self.framework == "tf" if self.framework == "tf"
else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
) )
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs): def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
......
...@@ -14,7 +14,7 @@ if is_torch_available(): ...@@ -14,7 +14,7 @@ if is_torch_available():
from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_outputs import BaseModelOutput
from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -60,7 +60,7 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline): ...@@ -60,7 +60,7 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
raise ValueError(f"The {self.__class__} is only available in PyTorch.") raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING) self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES)
def __call__( def __call__(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment