Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
32e3466d
Unverified
Commit
32e3466d
authored
Mar 13, 2023
by
Alara Dirik
Committed by
GitHub
Mar 13, 2023
Browse files
Add AutoModelForZeroShotImageClassification (#22087)
Adds AutoModelForZeroShotImageClassification to transformers
parent
b90fbc7e
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
98 additions
and
9 deletions
+98
-9
docs/source/en/model_doc/auto.mdx
docs/source/en/model_doc/auto.mdx
+8
-0
src/transformers/__init__.py
src/transformers/__init__.py
+8
-0
src/transformers/modelcard.py
src/transformers/modelcard.py
+2
-0
src/transformers/models/auto/__init__.py
src/transformers/models/auto/__init__.py
+8
-0
src/transformers/models/auto/modeling_auto.py
src/transformers/models/auto/modeling_auto.py
+13
-1
src/transformers/models/auto/modeling_tf_auto.py
src/transformers/models/auto/modeling_tf_auto.py
+21
-0
src/transformers/pipelines/__init__.py
src/transformers/pipelines/__init__.py
+4
-2
src/transformers/pipelines/zero_shot_image_classification.py
src/transformers/pipelines/zero_shot_image_classification.py
+10
-4
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_pt_objects.py
+10
-0
src/transformers/utils/dummy_tf_objects.py
src/transformers/utils/dummy_tf_objects.py
+10
-0
src/transformers/utils/fx.py
src/transformers/utils/fx.py
+2
-0
utils/update_metadata.py
utils/update_metadata.py
+2
-2
No files found.
docs/source/en/model_doc/auto.mdx
View file @
32e3466d
...
...
@@ -258,6 +258,14 @@ The following auto classes are available for the following computer vision tasks
[[
autodoc
]]
AutoModelForUniversalSegmentation
###
AutoModelForZeroShotImageClassification
[[
autodoc
]]
AutoModelForZeroShotImageClassification
###
TFAutoModelForZeroShotImageClassification
[[
autodoc
]]
TFAutoModelForZeroShotImageClassification
###
AutoModelForZeroShotObjectDetection
[[
autodoc
]]
AutoModelForZeroShotObjectDetection
...
...
src/transformers/__init__.py
View file @
32e3466d
...
...
@@ -1001,6 +1001,7 @@ else:
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_VISION_2_SEQ_MAPPING"
,
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING"
,
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING"
,
"MODEL_MAPPING"
,
"MODEL_WITH_LM_HEAD_MAPPING"
,
...
...
@@ -1033,6 +1034,7 @@ else:
"AutoModelForVideoClassification"
,
"AutoModelForVision2Seq"
,
"AutoModelForVisualQuestionAnswering"
,
"AutoModelForZeroShotImageClassification"
,
"AutoModelForZeroShotObjectDetection"
,
"AutoModelWithLMHead"
,
]
...
...
@@ -2785,6 +2787,7 @@ else:
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING"
,
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING"
,
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING"
,
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING"
,
"TF_MODEL_MAPPING"
,
"TF_MODEL_WITH_LM_HEAD_MAPPING"
,
"TFAutoModel"
,
...
...
@@ -2803,6 +2806,7 @@ else:
"TFAutoModelForTableQuestionAnswering"
,
"TFAutoModelForTokenClassification"
,
"TFAutoModelForVision2Seq"
,
"TFAutoModelForZeroShotImageClassification"
,
"TFAutoModelWithLMHead"
,
]
)
...
...
@@ -4514,6 +4518,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
,
MODEL_MAPPING
,
MODEL_WITH_LM_HEAD_MAPPING
,
...
...
@@ -4546,6 +4551,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification
,
AutoModelForVision2Seq
,
AutoModelForVisualQuestionAnswering
,
AutoModelForZeroShotImageClassification
,
AutoModelForZeroShotObjectDetection
,
AutoModelWithLMHead
,
)
...
...
@@ -5971,6 +5977,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
,
TF_MODEL_MAPPING
,
TF_MODEL_WITH_LM_HEAD_MAPPING
,
TFAutoModel
,
...
...
@@ -5989,6 +5996,7 @@ if TYPE_CHECKING:
TFAutoModelForTableQuestionAnswering
,
TFAutoModelForTokenClassification
,
TFAutoModelForVision2Seq
,
TFAutoModelForZeroShotImageClassification
,
TFAutoModelWithLMHead
,
)
from
.models.bart
import
(
...
...
src/transformers/modelcard.py
View file @
32e3466d
...
...
@@ -43,6 +43,7 @@ from .models.auto.modeling_auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
)
from
.training_args
import
ParallelMode
from
.utils
import
(
...
...
@@ -70,6 +71,7 @@ TASK_MAPPING = {
"token-classification"
:
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
,
"audio-classification"
:
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
,
"automatic-speech-recognition"
:
{
**
MODEL_FOR_CTC_MAPPING_NAMES
,
**
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
},
"zero-shot-image-classification"
:
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
}
logger
=
logging
.
get_logger
(
__name__
)
...
...
src/transformers/models/auto/__init__.py
View file @
32e3466d
...
...
@@ -69,6 +69,7 @@ else:
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING"
,
"MODEL_MAPPING"
,
"MODEL_WITH_LM_HEAD_MAPPING"
,
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING"
,
"AutoModel"
,
"AutoBackbone"
,
...
...
@@ -100,6 +101,7 @@ else:
"AutoModelForVisualQuestionAnswering"
,
"AutoModelForDocumentQuestionAnswering"
,
"AutoModelWithLMHead"
,
"AutoModelForZeroShotImageClassification"
,
"AutoModelForZeroShotObjectDetection"
,
]
...
...
@@ -126,6 +128,7 @@ else:
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING"
,
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING"
,
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING"
,
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING"
,
"TF_MODEL_MAPPING"
,
"TF_MODEL_WITH_LM_HEAD_MAPPING"
,
"TFAutoModel"
,
...
...
@@ -144,6 +147,7 @@ else:
"TFAutoModelForTableQuestionAnswering"
,
"TFAutoModelForTokenClassification"
,
"TFAutoModelForVision2Seq"
,
"TFAutoModelForZeroShotImageClassification"
,
"TFAutoModelWithLMHead"
,
]
...
...
@@ -226,6 +230,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
,
MODEL_MAPPING
,
MODEL_WITH_LM_HEAD_MAPPING
,
...
...
@@ -258,6 +263,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification
,
AutoModelForVision2Seq
,
AutoModelForVisualQuestionAnswering
,
AutoModelForZeroShotImageClassification
,
AutoModelForZeroShotObjectDetection
,
AutoModelWithLMHead
,
)
...
...
@@ -285,6 +291,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
,
TF_MODEL_MAPPING
,
TF_MODEL_WITH_LM_HEAD_MAPPING
,
TFAutoModel
,
...
...
@@ -303,6 +310,7 @@ if TYPE_CHECKING:
TFAutoModelForTableQuestionAnswering
,
TFAutoModelForTokenClassification
,
TFAutoModelForVision2Seq
,
TFAutoModelForZeroShotImageClassification
,
TFAutoModelWithLMHead
,
)
...
...
src/transformers/models/auto/modeling_auto.py
View file @
32e3466d
...
...
@@ -920,7 +920,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
]
)
_
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
[
# Model for Zero Shot Image Classification mapping
(
"align"
,
"AlignModel"
),
...
...
@@ -955,6 +955,9 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
...
...
@@ -1142,6 +1145,15 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
AutoModelForImageClassification
=
auto_class_update
(
AutoModelForImageClassification
,
head_doc
=
"image classification"
)
class
AutoModelForZeroShotImageClassification
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
AutoModelForZeroShotImageClassification
=
auto_class_update
(
AutoModelForZeroShotImageClassification
,
head_doc
=
"zero-shot image classification"
)
class
AutoModelForImageSegmentation
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
...
...
src/transformers/models/auto/modeling_tf_auto.py
View file @
32e3466d
...
...
@@ -209,6 +209,15 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
[
# Model for Zero Shot Image Classification mapping
(
"clip"
,
"TFCLIPModel"
),
]
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
=
OrderedDict
(
[
# Model for Semantic Segmentation mapping
...
...
@@ -424,6 +433,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
...
...
@@ -505,6 +517,15 @@ TFAutoModelForImageClassification = auto_class_update(
)
class
TFAutoModelForZeroShotImageClassification
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
TFAutoModelForZeroShotImageClassification
=
auto_class_update
(
TFAutoModelForZeroShotImageClassification
,
head_doc
=
"zero-shot image classification"
)
class
TFAutoModelForSemanticSegmentation
(
_BaseAutoModelClass
):
_model_mapping
=
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
...
...
src/transformers/pipelines/__init__.py
View file @
32e3466d
...
...
@@ -103,6 +103,7 @@ if is_tf_available():
TFAutoModelForTableQuestionAnswering
,
TFAutoModelForTokenClassification
,
TFAutoModelForVision2Seq
,
TFAutoModelForZeroShotImageClassification
,
)
if
is_torch_available
():
...
...
@@ -135,6 +136,7 @@ if is_torch_available():
AutoModelForVideoClassification
,
AutoModelForVision2Seq
,
AutoModelForVisualQuestionAnswering
,
AutoModelForZeroShotImageClassification
,
AutoModelForZeroShotObjectDetection
,
)
if
TYPE_CHECKING
:
...
...
@@ -290,8 +292,8 @@ SUPPORTED_TASKS = {
},
"zero-shot-image-classification"
:
{
"impl"
:
ZeroShotImageClassificationPipeline
,
"tf"
:
(
TFAutoModel
,)
if
is_tf_available
()
else
(),
"pt"
:
(
AutoModel
,)
if
is_torch_available
()
else
(),
"tf"
:
(
TFAutoModel
ForZeroShotImageClassification
,)
if
is_tf_available
()
else
(),
"pt"
:
(
AutoModel
ForZeroShotImageClassification
,)
if
is_torch_available
()
else
(),
"default"
:
{
"model"
:
{
"pt"
:
(
"openai/clip-vit-base-patch32"
,
"f4881ba"
),
...
...
src/transformers/pipelines/zero_shot_image_classification.py
View file @
32e3466d
...
...
@@ -18,9 +18,10 @@ if is_vision_available():
from
..image_utils
import
load_image
if
is_torch_available
():
pass
from
..models.auto.modeling_auto
import
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if
is_tf_available
():
from
..models.auto.modeling_tf_auto
import
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
from
..tf_utils
import
stable_softmax
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -64,8 +65,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
super
().
__init__
(
**
kwargs
)
requires_backends
(
self
,
"vision"
)
# No specific FOR_XXX available yet
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
self
.
check_model_type
(
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
if
self
.
framework
==
"tf"
else
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
)
def
__call__
(
self
,
images
:
Union
[
str
,
List
[
str
],
"Image"
,
List
[
"Image"
]],
**
kwargs
):
"""
...
...
@@ -137,9 +141,11 @@ class ZeroShotImageClassificationPipeline(Pipeline):
if
self
.
framework
==
"pt"
:
probs
=
logits
.
softmax
(
dim
=-
1
).
squeeze
(
-
1
)
scores
=
probs
.
tolist
()
el
se
:
el
if
self
.
framework
==
"tf"
:
probs
=
stable_softmax
(
logits
,
axis
=-
1
)
scores
=
probs
.
numpy
().
tolist
()
else
:
raise
ValueError
(
f
"Unsupported framework:
{
self
.
framework
}
"
)
result
=
[
{
"score"
:
score
,
"label"
:
candidate_label
}
...
...
src/transformers/utils/dummy_pt_objects.py
View file @
32e3466d
...
...
@@ -526,6 +526,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
=
None
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
=
None
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
=
None
...
...
@@ -738,6 +741,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
requires_backends
(
self
,
[
"torch"
])
class
AutoModelForZeroShotImageClassification
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
class
AutoModelForZeroShotObjectDetection
(
metaclass
=
DummyObject
):
_backends
=
[
"torch"
]
...
...
src/transformers/utils/dummy_tf_objects.py
View file @
32e3466d
...
...
@@ -316,6 +316,9 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
TF_MODEL_FOR_VISION_2_SEQ_MAPPING
=
None
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
=
None
TF_MODEL_MAPPING
=
None
...
...
@@ -434,6 +437,13 @@ class TFAutoModelForVision2Seq(metaclass=DummyObject):
requires_backends
(
self
,
[
"tf"
])
class
TFAutoModelForZeroShotImageClassification
(
metaclass
=
DummyObject
):
_backends
=
[
"tf"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"tf"
])
class
TFAutoModelWithLMHead
(
metaclass
=
DummyObject
):
_backends
=
[
"tf"
]
...
...
src/transformers/utils/fx.py
View file @
32e3466d
...
...
@@ -50,6 +50,7 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
MODEL_MAPPING_NAMES
,
)
from
..utils
import
ENV_VARS_TRUE_VALUES
,
TORCH_FX_REQUIRED_VERSION
,
is_torch_fx_available
...
...
@@ -79,6 +80,7 @@ def _generate_supported_model_class_names(
"token-classification"
:
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
,
"masked-image-modeling"
:
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
,
"image-classification"
:
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
"zero-shot-image-classification"
:
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
,
"ctc"
:
MODEL_FOR_CTC_MAPPING_NAMES
,
"audio-classification"
:
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
,
"semantic-segmentation"
:
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
,
...
...
utils/update_metadata.py
View file @
32e3466d
...
...
@@ -93,8 +93,8 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
(
"image-to-text"
,
"MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES"
,
"AutoModelForVision2Seq"
),
(
"zero-shot-image-classification"
,
"
_
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES"
,
"AutoModel"
,
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES"
,
"AutoModel
ForZeroShotImageClassification
"
,
),
(
"depth-estimation"
,
"MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES"
,
"AutoModelForDepthEstimation"
),
(
"video-classification"
,
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES"
,
"AutoModelForVideoClassification"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment