Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
32e3466d
Unverified
Commit
32e3466d
authored
Mar 13, 2023
by
Alara Dirik
Committed by
GitHub
Mar 13, 2023
Browse files
Add AutoModelForZeroShotImageClassification (#22087)
Adds AutoModelForZeroShotImageClassification to transformers
parent
b90fbc7e
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
98 additions
and
9 deletions
+98
-9
docs/source/en/model_doc/auto.mdx
docs/source/en/model_doc/auto.mdx
+8
-0
src/transformers/__init__.py
src/transformers/__init__.py
+8
-0
src/transformers/modelcard.py
src/transformers/modelcard.py
+2
-0
src/transformers/models/auto/__init__.py
src/transformers/models/auto/__init__.py
+8
-0
src/transformers/models/auto/modeling_auto.py
src/transformers/models/auto/modeling_auto.py
+13
-1
src/transformers/models/auto/modeling_tf_auto.py
src/transformers/models/auto/modeling_tf_auto.py
+21
-0
src/transformers/pipelines/__init__.py
src/transformers/pipelines/__init__.py
+4
-2
src/transformers/pipelines/zero_shot_image_classification.py
src/transformers/pipelines/zero_shot_image_classification.py
+10
-4
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_pt_objects.py
+10
-0
src/transformers/utils/dummy_tf_objects.py
src/transformers/utils/dummy_tf_objects.py
+10
-0
src/transformers/utils/fx.py
src/transformers/utils/fx.py
+2
-0
utils/update_metadata.py
utils/update_metadata.py
+2
-2
No files found.
docs/source/en/model_doc/auto.mdx
View file @
32e3466d
...
...
@@ -258,6 +258,14 @@ The following auto classes are available for the following computer vision tasks
[[
autodoc
]]
AutoModelForUniversalSegmentation
###
AutoModelForZeroShotImageClassification
[[
autodoc
]]
AutoModelForZeroShotImageClassification
###
TFAutoModelForZeroShotImageClassification
[[
autodoc
]]
TFAutoModelForZeroShotImageClassification
###
AutoModelForZeroShotObjectDetection
[[
autodoc
]]
AutoModelForZeroShotObjectDetection
...
...
src/transformers/__init__.py
View file @
32e3466d
...
...
@@ -1001,6 +1001,7 @@ else:
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_VISION_2_SEQ_MAPPING"
,
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING"
,
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING"
,
"MODEL_MAPPING"
,
"MODEL_WITH_LM_HEAD_MAPPING"
,
...
...
@@ -1033,6 +1034,7 @@ else:
"AutoModelForVideoClassification"
,
"AutoModelForVision2Seq"
,
"AutoModelForVisualQuestionAnswering"
,
"AutoModelForZeroShotImageClassification"
,
"AutoModelForZeroShotObjectDetection"
,
"AutoModelWithLMHead"
,
]
...
...
@@ -2785,6 +2787,7 @@ else:
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING"
,
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING"
,
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING"
,
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING"
,
"TF_MODEL_MAPPING"
,
"TF_MODEL_WITH_LM_HEAD_MAPPING"
,
"TFAutoModel"
,
...
...
@@ -2803,6 +2806,7 @@ else:
"TFAutoModelForTableQuestionAnswering"
,
"TFAutoModelForTokenClassification"
,
"TFAutoModelForVision2Seq"
,
"TFAutoModelForZeroShotImageClassification"
,
"TFAutoModelWithLMHead"
,
]
)
...
...
@@ -4514,6 +4518,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
,
MODEL_MAPPING
,
MODEL_WITH_LM_HEAD_MAPPING
,
...
...
@@ -4546,6 +4551,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification
,
AutoModelForVision2Seq
,
AutoModelForVisualQuestionAnswering
,
AutoModelForZeroShotImageClassification
,
AutoModelForZeroShotObjectDetection
,
AutoModelWithLMHead
,
)
...
...
@@ -5971,6 +5977,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
,
TF_MODEL_MAPPING
,
TF_MODEL_WITH_LM_HEAD_MAPPING
,
TFAutoModel
,
...
...
@@ -5989,6 +5996,7 @@ if TYPE_CHECKING:
TFAutoModelForTableQuestionAnswering
,
TFAutoModelForTokenClassification
,
TFAutoModelForVision2Seq
,
TFAutoModelForZeroShotImageClassification
,
TFAutoModelWithLMHead
,
)
from
.models.bart
import
(
...
...
src/transformers/modelcard.py
View file @
32e3466d
...
...
@@ -43,6 +43,7 @@ from .models.auto.modeling_auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
)
from
.training_args
import
ParallelMode
from
.utils
import
(
...
...
@@ -70,6 +71,7 @@ TASK_MAPPING = {
"token-classification"
:
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
,
"audio-classification"
:
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
,
"automatic-speech-recognition"
:
{
**
MODEL_FOR_CTC_MAPPING_NAMES
,
**
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
},
"zero-shot-image-classification"
:
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
}
logger
=
logging
.
get_logger
(
__name__
)
...
...
src/transformers/models/auto/__init__.py
View file @
32e3466d
...
...
@@ -69,6 +69,7 @@ else:
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING"
,
"MODEL_MAPPING"
,
"MODEL_WITH_LM_HEAD_MAPPING"
,
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING"
,
"AutoModel"
,
"AutoBackbone"
,
...
...
@@ -100,6 +101,7 @@ else:
"AutoModelForVisualQuestionAnswering"
,
"AutoModelForDocumentQuestionAnswering"
,
"AutoModelWithLMHead"
,
"AutoModelForZeroShotImageClassification"
,
"AutoModelForZeroShotObjectDetection"
,
]
...
...
@@ -126,6 +128,7 @@ else:
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING"
,
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING"
,
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING"
,
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING"
,
"TF_MODEL_MAPPING"
,
"TF_MODEL_WITH_LM_HEAD_MAPPING"
,
"TFAutoModel"
,
...
...
@@ -144,6 +147,7 @@ else:
"TFAutoModelForTableQuestionAnswering"
,
"TFAutoModelForTokenClassification"
,
"TFAutoModelForVision2Seq"
,
"TFAutoModelForZeroShotImageClassification"
,
"TFAutoModelWithLMHead"
,
]
...
...
@@ -226,6 +230,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
,
MODEL_MAPPING
,
MODEL_WITH_LM_HEAD_MAPPING
,
...
...
@@ -258,6 +263,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification
,
AutoModelForVision2Seq
,
AutoModelForVisualQuestionAnswering
,
AutoModelForZeroShotImageClassification
,
AutoModelForZeroShotObjectDetection
,
AutoModelWithLMHead
,
)
...
...
@@ -285,6 +291,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
,
TF_MODEL_MAPPING
,
TF_MODEL_WITH_LM_HEAD_MAPPING
,
TFAutoModel
,
...
...
@@ -303,6 +310,7 @@ if TYPE_CHECKING:
TFAutoModelForTableQuestionAnswering
,
TFAutoModelForTokenClassification
,
TFAutoModelForVision2Seq
,
TFAutoModelForZeroShotImageClassification
,
TFAutoModelWithLMHead
,
)
...
...
src/transformers/models/auto/modeling_auto.py
View file @
32e3466d
...
...
@@ -920,7 +920,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
]
)
_
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
[
# Model for Zero Shot Image Classification mapping
(
"align"
,
"AlignModel"
),
...
...
@@ -955,6 +955,9 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
...
...
@@ -1142,6 +1145,15 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
AutoModelForImageClassification
=
auto_class_update
(
AutoModelForImageClassification
,
head_doc
=
"image classification"
)
class
AutoModelForZeroShotImageClassification
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
AutoModelForZeroShotImageClassification
=
auto_class_update
(
AutoModelForZeroShotImageClassification
,
head_doc
=
"zero-shot image classification"
)
class
AutoModelForImageSegmentation
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
...
...
src/transformers/models/auto/modeling_tf_auto.py
View file @
32e3466d
...
...
@@ -209,6 +209,15 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
[
# Model for Zero Shot Image Classification mapping
(
"clip"
,
"TFCLIPModel"
),
]
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
=
OrderedDict
(
[
# Model for Semantic Segmentation mapping
...
...
@@ -424,6 +433,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
...
...
@@ -505,6 +517,15 @@ TFAutoModelForImageClassification = auto_class_update(
)
class
TFAutoModelForZeroShotImageClassification
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
TFAutoModelForZeroShotImageClassification
=
auto_class_update
(
TFAutoModelForZeroShotImageClassification
,
head_doc
=
"zero-shot image classification"
)
class
TFAutoModelForSemanticSegmentation
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
...
...
src/transformers/pipelines/__init__.py
View file @
32e3466d
...
...
@@ -103,6 +103,7 @@ if is_tf_available():
TFAutoModelForTableQuestionAnswering
,
TFAutoModelForTokenClassification
,
TFAutoModelForVision2Seq
,
TFAutoModelForZeroShotImageClassification
,
)
if
is_torch_available
():
...
...
@@ -135,6 +136,7 @@ if is_torch_available():
AutoModelForVideoClassification
,
AutoModelForVision2Seq
,
AutoModelForVisualQuestionAnswering
,
AutoModelForZeroShotImageClassification
,
AutoModelForZeroShotObjectDetection
,
)
if
TYPE_CHECKING
:
...
...
@@ -290,8 +292,8 @@ SUPPORTED_TASKS = {
},
"zero-shot-image-classification"
:
{
"impl"
:
ZeroShotImageClassificationPipeline
,
"tf"
:
(
TFAutoModel
,)
if
is_tf_available
()
else
(),
"pt"
:
(
AutoModel
,)
if
is_torch_available
()
else
(),
"tf"
:
(
TFAutoModel
ForZeroShotImageClassification
,)
if
is_tf_available
()
else
(),
"pt"
:
(
AutoModel
ForZeroShotImageClassification
,)
if
is_torch_available
()
else
(),
"default"
:
{
"model"
:
{
"pt"
:
(
"openai/clip-vit-base-patch32"
,
"f4881ba"
),
...
...
src/transformers/pipelines/zero_shot_image_classification.py
View file @
32e3466d
...
...
@@ -18,9 +18,10 @@ if is_vision_available():
from
..image_utils
import
load_image
if
is_torch_available
():
pass
from
..models.auto.modeling_auto
import
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if
is_tf_available
():
from
..models.auto.modeling_tf_auto
import
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
from
..tf_utils
import
stable_softmax
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
super
().
__init__
(
**
kwargs
)
requires_backends
(
self
,
"vision"
)
# No specific FOR_XXX available yet
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
self
.
check_model_type
(
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if
self
.
framework
==
"tf"
else
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
)
def
__call__
(
self
,
images
:
Union
[
str
,
List
[
str
],
"Image"
,
List
[
"Image"
]],
**
kwargs
):
"""
...
...
@@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
if
self
.
framework
==
"pt"
:
probs
=
logits
.
softmax
(
dim
=-
1
).
squeeze
(
-
1
)
scores
=
probs
.
tolist
()
el
se
:
el
if
self
.
framework
==
"tf"
:
probs
=
stable_softmax
(
logits
,
axis
=-
1
)
scores
=
probs
.
numpy
().
tolist
()
else
:
raise
ValueError
(
f
"Unsupported framework:
{
self
.
framework
}
"
)
result
=
[
{
"score"
:
score
,
"label"
:
candidate_label
}
...
...
src/transformers/utils/dummy_pt_objects.py
View file @
32e3466d
...
...
@@ -526,6 +526,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
=
None
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
=
None
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
=
None
...
...
@@ -738,6 +741,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
requires_backends
(
self
,
[
"torch"
])
class
AutoModelForZeroShotImageClassification
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
AutoModelForZeroShotObjectDetection
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
...
...
src/transformers/utils/dummy_tf_objects.py
View file @
32e3466d
...
...
@@ -316,6 +316,9 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
=
None
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
=
None
TF_MODEL_MAPPING
=
None
...
...
@@ -434,6 +437,13 @@ class TFAutoModelForVision2Seq(metaclass=DummyObject):
requires_backends
(
self
,
[
"tf"
])
class
TFAutoModelForZeroShotImageClassification
(
metaclass
=
DummyObject
):
_backends
=
[
"tf"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"tf"
])
class
TFAutoModelWithLMHead
(
metaclass
=
DummyObject
):
_backends
=
[
"tf"
]
...
...
src/transformers/utils/fx.py
View file @
32e3466d
...
...
@@ -50,6 +50,7 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
MODEL_MAPPING_NAMES
,
)
from
..utils
import
ENV_VARS_TRUE_VALUES
,
TORCH_FX_REQUIRED_VERSION
,
is_torch_fx_available
...
...
@@ -79,6 +80,7 @@ def _generate_supported_model_class_names(
"token-classification"
:
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
,
"masked-image-modeling"
:
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
,
"image-classification"
:
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
"zero-shot-image-classification"
:
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
"ctc"
:
MODEL_FOR_CTC_MAPPING_NAMES
,
"audio-classification"
:
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
,
"semantic-segmentation"
:
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
,
...
...
utils/update_metadata.py
View file @
32e3466d
...
...
@@ -93,8 +93,8 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
(
"image-to-text"
,
"MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES"
,
"AutoModelForVision2Seq"
),
(
"zero-shot-image-classification"
,
"
_
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES"
,
"AutoModel"
,
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES"
,
"AutoModel
ForZeroShotImageClassification
"
,
),
(
"depth-estimation"
,
"MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES"
,
"AutoModelForDepthEstimation"
),
(
"video-classification"
,
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES"
,
"AutoModelForVideoClassification"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment