Unverified Commit 32e3466d authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Add AutoModelForZeroShotImageClassification (#22087)

Adds AutoModelForZeroShotImageClassification to transformers
parent b90fbc7e
...@@ -258,6 +258,14 @@ The following auto classes are available for the following computer vision tasks ...@@ -258,6 +258,14 @@ The following auto classes are available for the following computer vision tasks
[[autodoc]] AutoModelForUniversalSegmentation [[autodoc]] AutoModelForUniversalSegmentation
### AutoModelForZeroShotImageClassification
[[autodoc]] AutoModelForZeroShotImageClassification
### TFAutoModelForZeroShotImageClassification
[[autodoc]] TFAutoModelForZeroShotImageClassification
### AutoModelForZeroShotObjectDetection ### AutoModelForZeroShotObjectDetection
[[autodoc]] AutoModelForZeroShotObjectDetection [[autodoc]] AutoModelForZeroShotObjectDetection
......
...@@ -1001,6 +1001,7 @@ else: ...@@ -1001,6 +1001,7 @@ else:
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING", "MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"MODEL_MAPPING", "MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING",
...@@ -1033,6 +1034,7 @@ else: ...@@ -1033,6 +1034,7 @@ else:
"AutoModelForVideoClassification", "AutoModelForVideoClassification",
"AutoModelForVision2Seq", "AutoModelForVision2Seq",
"AutoModelForVisualQuestionAnswering", "AutoModelForVisualQuestionAnswering",
"AutoModelForZeroShotImageClassification",
"AutoModelForZeroShotObjectDetection", "AutoModelForZeroShotObjectDetection",
"AutoModelWithLMHead", "AutoModelWithLMHead",
] ]
...@@ -2785,6 +2787,7 @@ else: ...@@ -2785,6 +2787,7 @@ else:
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING", "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_MAPPING", "TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING", "TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel", "TFAutoModel",
...@@ -2803,6 +2806,7 @@ else: ...@@ -2803,6 +2806,7 @@ else:
"TFAutoModelForTableQuestionAnswering", "TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification", "TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq", "TFAutoModelForVision2Seq",
"TFAutoModelForZeroShotImageClassification",
"TFAutoModelWithLMHead", "TFAutoModelWithLMHead",
] ]
) )
...@@ -4514,6 +4518,7 @@ if TYPE_CHECKING: ...@@ -4514,6 +4518,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING, MODEL_WITH_LM_HEAD_MAPPING,
...@@ -4546,6 +4551,7 @@ if TYPE_CHECKING: ...@@ -4546,6 +4551,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification, AutoModelForVideoClassification,
AutoModelForVision2Seq, AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering, AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotImageClassification,
AutoModelForZeroShotObjectDetection, AutoModelForZeroShotObjectDetection,
AutoModelWithLMHead, AutoModelWithLMHead,
) )
...@@ -5971,6 +5977,7 @@ if TYPE_CHECKING: ...@@ -5971,6 +5977,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_MAPPING, TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel, TFAutoModel,
...@@ -5989,6 +5996,7 @@ if TYPE_CHECKING: ...@@ -5989,6 +5996,7 @@ if TYPE_CHECKING:
TFAutoModelForTableQuestionAnswering, TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification, TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq, TFAutoModelForVision2Seq,
TFAutoModelForZeroShotImageClassification,
TFAutoModelWithLMHead, TFAutoModelWithLMHead,
) )
from .models.bart import ( from .models.bart import (
......
...@@ -43,6 +43,7 @@ from .models.auto.modeling_auto import ( ...@@ -43,6 +43,7 @@ from .models.auto.modeling_auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
) )
from .training_args import ParallelMode from .training_args import ParallelMode
from .utils import ( from .utils import (
...@@ -70,6 +71,7 @@ TASK_MAPPING = { ...@@ -70,6 +71,7 @@ TASK_MAPPING = {
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
"automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES}, "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
} }
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -69,6 +69,7 @@ else: ...@@ -69,6 +69,7 @@ else:
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING", "MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING",
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"AutoModel", "AutoModel",
"AutoBackbone", "AutoBackbone",
...@@ -100,6 +101,7 @@ else: ...@@ -100,6 +101,7 @@ else:
"AutoModelForVisualQuestionAnswering", "AutoModelForVisualQuestionAnswering",
"AutoModelForDocumentQuestionAnswering", "AutoModelForDocumentQuestionAnswering",
"AutoModelWithLMHead", "AutoModelWithLMHead",
"AutoModelForZeroShotImageClassification",
"AutoModelForZeroShotObjectDetection", "AutoModelForZeroShotObjectDetection",
] ]
...@@ -126,6 +128,7 @@ else: ...@@ -126,6 +128,7 @@ else:
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING", "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_MAPPING", "TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING", "TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel", "TFAutoModel",
...@@ -144,6 +147,7 @@ else: ...@@ -144,6 +147,7 @@ else:
"TFAutoModelForTableQuestionAnswering", "TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification", "TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq", "TFAutoModelForVision2Seq",
"TFAutoModelForZeroShotImageClassification",
"TFAutoModelWithLMHead", "TFAutoModelWithLMHead",
] ]
...@@ -226,6 +230,7 @@ if TYPE_CHECKING: ...@@ -226,6 +230,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING, MODEL_WITH_LM_HEAD_MAPPING,
...@@ -258,6 +263,7 @@ if TYPE_CHECKING: ...@@ -258,6 +263,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification, AutoModelForVideoClassification,
AutoModelForVision2Seq, AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering, AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotImageClassification,
AutoModelForZeroShotObjectDetection, AutoModelForZeroShotObjectDetection,
AutoModelWithLMHead, AutoModelWithLMHead,
) )
...@@ -285,6 +291,7 @@ if TYPE_CHECKING: ...@@ -285,6 +291,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_MAPPING, TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel, TFAutoModel,
...@@ -303,6 +310,7 @@ if TYPE_CHECKING: ...@@ -303,6 +310,7 @@ if TYPE_CHECKING:
TFAutoModelForTableQuestionAnswering, TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification, TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq, TFAutoModelForVision2Seq,
TFAutoModelForZeroShotImageClassification,
TFAutoModelWithLMHead, TFAutoModelWithLMHead,
) )
......
...@@ -920,7 +920,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict( ...@@ -920,7 +920,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
] ]
) )
_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Zero Shot Image Classification mapping # Model for Zero Shot Image Classification mapping
("align", "AlignModel"), ("align", "AlignModel"),
...@@ -955,6 +955,9 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping( ...@@ -955,6 +955,9 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
) )
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping( MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
) )
...@@ -1142,6 +1145,15 @@ class AutoModelForImageClassification(_BaseAutoModelClass): ...@@ -1142,6 +1145,15 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
AutoModelForZeroShotImageClassification = auto_class_update(
AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)
class AutoModelForImageSegmentation(_BaseAutoModelClass): class AutoModelForImageSegmentation(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
......
...@@ -209,6 +209,15 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -209,6 +209,15 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
] ]
) )
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Zero Shot Image Classification mapping
("clip", "TFCLIPModel"),
]
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Semantic Segmentation mapping # Model for Semantic Segmentation mapping
...@@ -424,6 +433,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( ...@@ -424,6 +433,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
) )
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping( TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
) )
...@@ -505,6 +517,15 @@ TFAutoModelForImageClassification = auto_class_update( ...@@ -505,6 +517,15 @@ TFAutoModelForImageClassification = auto_class_update(
) )
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
TFAutoModelForZeroShotImageClassification = auto_class_update(
TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass): class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
......
...@@ -103,6 +103,7 @@ if is_tf_available(): ...@@ -103,6 +103,7 @@ if is_tf_available():
TFAutoModelForTableQuestionAnswering, TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification, TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq, TFAutoModelForVision2Seq,
TFAutoModelForZeroShotImageClassification,
) )
if is_torch_available(): if is_torch_available():
...@@ -135,6 +136,7 @@ if is_torch_available(): ...@@ -135,6 +136,7 @@ if is_torch_available():
AutoModelForVideoClassification, AutoModelForVideoClassification,
AutoModelForVision2Seq, AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering, AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotImageClassification,
AutoModelForZeroShotObjectDetection, AutoModelForZeroShotObjectDetection,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -290,8 +292,8 @@ SUPPORTED_TASKS = { ...@@ -290,8 +292,8 @@ SUPPORTED_TASKS = {
}, },
"zero-shot-image-classification": { "zero-shot-image-classification": {
"impl": ZeroShotImageClassificationPipeline, "impl": ZeroShotImageClassificationPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (), "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (), "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
"default": { "default": {
"model": { "model": {
"pt": ("openai/clip-vit-base-patch32", "f4881ba"), "pt": ("openai/clip-vit-base-patch32", "f4881ba"),
......
...@@ -18,9 +18,10 @@ if is_vision_available(): ...@@ -18,9 +18,10 @@ if is_vision_available():
from ..image_utils import load_image from ..image_utils import load_image
if is_torch_available(): if is_torch_available():
pass from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if is_tf_available(): if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
from ..tf_utils import stable_softmax from ..tf_utils import stable_softmax
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline): ...@@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
super().__init__(**kwargs) super().__init__(**kwargs)
requires_backends(self, "vision") requires_backends(self, "vision")
# No specific FOR_XXX available yet self.check_model_type(
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING) TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if self.framework == "tf"
else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
)
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs): def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
""" """
...@@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline): ...@@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
if self.framework == "pt": if self.framework == "pt":
probs = logits.softmax(dim=-1).squeeze(-1) probs = logits.softmax(dim=-1).squeeze(-1)
scores = probs.tolist() scores = probs.tolist()
else: elif self.framework == "tf":
probs = stable_softmax(logits, axis=-1) probs = stable_softmax(logits, axis=-1)
scores = probs.numpy().tolist() scores = probs.numpy().tolist()
else:
raise ValueError(f"Unsupported framework: {self.framework}")
result = [ result = [
{"score": score, "label": candidate_label} {"score": score, "label": candidate_label}
......
...@@ -526,6 +526,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None ...@@ -526,6 +526,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None
...@@ -738,6 +741,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject): ...@@ -738,6 +741,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class AutoModelForZeroShotImageClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForZeroShotObjectDetection(metaclass=DummyObject): class AutoModelForZeroShotObjectDetection(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -316,6 +316,9 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None ...@@ -316,6 +316,9 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = None
TF_MODEL_MAPPING = None TF_MODEL_MAPPING = None
...@@ -434,6 +437,13 @@ class TFAutoModelForVision2Seq(metaclass=DummyObject): ...@@ -434,6 +437,13 @@ class TFAutoModelForVision2Seq(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
class TFAutoModelForZeroShotImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFAutoModelWithLMHead(metaclass=DummyObject): class TFAutoModelWithLMHead(metaclass=DummyObject):
_backends = ["tf"] _backends = ["tf"]
......
...@@ -50,6 +50,7 @@ from ..models.auto.modeling_auto import ( ...@@ -50,6 +50,7 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES, MODEL_MAPPING_NAMES,
) )
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
...@@ -79,6 +80,7 @@ def _generate_supported_model_class_names( ...@@ -79,6 +80,7 @@ def _generate_supported_model_class_names(
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
"zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
"ctc": MODEL_FOR_CTC_MAPPING_NAMES, "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
"semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, "semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
......
...@@ -93,8 +93,8 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [ ...@@ -93,8 +93,8 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"), ("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"),
( (
"zero-shot-image-classification", "zero-shot-image-classification",
"_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES", "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
"AutoModel", "AutoModelForZeroShotImageClassification",
), ),
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"), ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"), ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment