Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
66964c00
Unverified
Commit
66964c00
authored
Jan 11, 2024
by
amyeroberts
Committed by
GitHub
Jan 11, 2024
Browse files
Enable multi-label image classification in pipeline (#28433)
Enable multi-label image classification
parent
8205b264
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
132 additions
and
17 deletions
+132
-17
src/transformers/pipelines/image_classification.py
src/transformers/pipelines/image_classification.py
+86
-17
tests/pipelines/test_pipelines_image_classification.py
tests/pipelines/test_pipelines_image_classification.py
+46
-0
No files found.
src/transformers/pipelines/image_classification.py
View file @
66964c00
from
typing
import
List
,
Union
from
typing
import
List
,
Union
import
numpy
as
np
from
..utils
import
(
from
..utils
import
(
ExplicitEnum
,
add_end_docstrings
,
add_end_docstrings
,
is_tf_available
,
is_tf_available
,
is_torch_available
,
is_torch_available
,
...
@@ -17,10 +20,7 @@ if is_vision_available():
...
@@ -17,10 +20,7 @@ if is_vision_available():
from
..image_utils
import
load_image
from
..image_utils
import
load_image
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
from
..models.auto.modeling_tf_auto
import
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
from
..models.auto.modeling_tf_auto
import
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
from
..tf_utils
import
stable_softmax
if
is_torch_available
():
if
is_torch_available
():
from
..models.auto.modeling_auto
import
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
from
..models.auto.modeling_auto
import
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
...
@@ -28,7 +28,38 @@ if is_torch_available():
...
@@ -28,7 +28,38 @@ if is_torch_available():
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
@
add_end_docstrings
(
PIPELINE_INIT_ARGS
)
# Copied from transformers.pipelines.text_classification.sigmoid
def
sigmoid
(
_outputs
):
return
1.0
/
(
1.0
+
np
.
exp
(
-
_outputs
))
# Copied from transformers.pipelines.text_classification.softmax
def
softmax
(
_outputs
):
maxes
=
np
.
max
(
_outputs
,
axis
=-
1
,
keepdims
=
True
)
shifted_exp
=
np
.
exp
(
_outputs
-
maxes
)
return
shifted_exp
/
shifted_exp
.
sum
(
axis
=-
1
,
keepdims
=
True
)
# Copied from transformers.pipelines.text_classification.ClassificationFunction
class
ClassificationFunction
(
ExplicitEnum
):
SIGMOID
=
"sigmoid"
SOFTMAX
=
"softmax"
NONE
=
"none"
@
add_end_docstrings
(
PIPELINE_INIT_ARGS
,
r
"""
function_to_apply (`str`, *optional*, defaults to `"default"`):
The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:
- `"default"`: if the model has a single label, will apply the sigmoid function on the output. If the model
has several labels, will apply the softmax function on the output.
- `"sigmoid"`: Applies the sigmoid function on the output.
- `"softmax"`: Applies the softmax function on the output.
- `"none"`: Does not apply any function on the output.
"""
,
)
class
ImageClassificationPipeline
(
Pipeline
):
class
ImageClassificationPipeline
(
Pipeline
):
"""
"""
Image classification pipeline using any `AutoModelForImageClassification`. This pipeline predicts the class of an
Image classification pipeline using any `AutoModelForImageClassification`. This pipeline predicts the class of an
...
@@ -53,6 +84,8 @@ class ImageClassificationPipeline(Pipeline):
...
@@ -53,6 +84,8 @@ class ImageClassificationPipeline(Pipeline):
[huggingface.co/models](https://huggingface.co/models?filter=image-classification).
[huggingface.co/models](https://huggingface.co/models?filter=image-classification).
"""
"""
function_to_apply
:
ClassificationFunction
=
ClassificationFunction
.
NONE
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
requires_backends
(
self
,
"vision"
)
requires_backends
(
self
,
"vision"
)
...
@@ -62,13 +95,17 @@ class ImageClassificationPipeline(Pipeline):
...
@@ -62,13 +95,17 @@ class ImageClassificationPipeline(Pipeline):
else
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
else
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
)
def
_sanitize_parameters
(
self
,
top_k
=
None
,
timeout
=
None
):
def
_sanitize_parameters
(
self
,
top_k
=
None
,
function_to_apply
=
None
,
timeout
=
None
):
preprocess_params
=
{}
preprocess_params
=
{}
if
timeout
is
not
None
:
if
timeout
is
not
None
:
preprocess_params
[
"timeout"
]
=
timeout
preprocess_params
[
"timeout"
]
=
timeout
postprocess_params
=
{}
postprocess_params
=
{}
if
top_k
is
not
None
:
if
top_k
is
not
None
:
postprocess_params
[
"top_k"
]
=
top_k
postprocess_params
[
"top_k"
]
=
top_k
if
isinstance
(
function_to_apply
,
str
):
function_to_apply
=
ClassificationFunction
(
function_to_apply
.
lower
())
if
function_to_apply
is
not
None
:
postprocess_params
[
"function_to_apply"
]
=
function_to_apply
return
preprocess_params
,
{},
postprocess_params
return
preprocess_params
,
{},
postprocess_params
def
__call__
(
self
,
images
:
Union
[
str
,
List
[
str
],
"Image.Image"
,
List
[
"Image.Image"
]],
**
kwargs
):
def
__call__
(
self
,
images
:
Union
[
str
,
List
[
str
],
"Image.Image"
,
List
[
"Image.Image"
]],
**
kwargs
):
...
@@ -86,6 +123,21 @@ class ImageClassificationPipeline(Pipeline):
...
@@ -86,6 +123,21 @@ class ImageClassificationPipeline(Pipeline):
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
images.
images.
function_to_apply (`str`, *optional*, defaults to `"default"`):
The function to apply to the model outputs in order to retrieve the scores. Accepts four different
values:
If this argument is not specified, then it will apply the following functions according to the number
of labels:
- If the model has a single label, will apply the sigmoid function on the output.
- If the model has several labels, will apply the softmax function on the output.
Possible values are:
- `"sigmoid"`: Applies the sigmoid function on the output.
- `"softmax"`: Applies the softmax function on the output.
- `"none"`: Does not apply any function on the output.
top_k (`int`, *optional*, defaults to 5):
top_k (`int`, *optional*, defaults to 5):
The number of top labels that will be returned by the pipeline. If the provided number is higher than
The number of top labels that will be returned by the pipeline. If the provided number is higher than
the number of labels available in the model configuration, it will default to the number of labels.
the number of labels available in the model configuration, it will default to the number of labels.
...
@@ -114,20 +166,37 @@ class ImageClassificationPipeline(Pipeline):
...
@@ -114,20 +166,37 @@ class ImageClassificationPipeline(Pipeline):
model_outputs
=
self
.
model
(
**
model_inputs
)
model_outputs
=
self
.
model
(
**
model_inputs
)
return
model_outputs
return
model_outputs
def
postprocess
(
self
,
model_outputs
,
top_k
=
5
):
def
postprocess
(
self
,
model_outputs
,
function_to_apply
=
None
,
top_k
=
5
):
if
function_to_apply
is
None
:
if
self
.
model
.
config
.
problem_type
==
"multi_label_classification"
or
self
.
model
.
config
.
num_labels
==
1
:
function_to_apply
=
ClassificationFunction
.
SIGMOID
elif
self
.
model
.
config
.
problem_type
==
"single_label_classification"
or
self
.
model
.
config
.
num_labels
>
1
:
function_to_apply
=
ClassificationFunction
.
SOFTMAX
elif
hasattr
(
self
.
model
.
config
,
"function_to_apply"
)
and
function_to_apply
is
None
:
function_to_apply
=
self
.
model
.
config
.
function_to_apply
else
:
function_to_apply
=
ClassificationFunction
.
NONE
if
top_k
>
self
.
model
.
config
.
num_labels
:
if
top_k
>
self
.
model
.
config
.
num_labels
:
top_k
=
self
.
model
.
config
.
num_labels
top_k
=
self
.
model
.
config
.
num_labels
if
self
.
framework
==
"pt"
:
outputs
=
model_outputs
[
"logits"
][
0
]
probs
=
model_outputs
.
logits
.
softmax
(
-
1
)[
0
]
outputs
=
outputs
.
numpy
()
scores
,
ids
=
probs
.
topk
(
top_k
)
elif
self
.
framework
==
"tf"
:
if
function_to_apply
==
ClassificationFunction
.
SIGMOID
:
probs
=
stable_softmax
(
model_outputs
.
logits
,
axis
=-
1
)[
0
]
scores
=
sigmoid
(
outputs
)
topk
=
tf
.
math
.
top_k
(
probs
,
k
=
top_k
)
elif
function_to_apply
==
ClassificationFunction
.
SOFTMAX
:
scores
,
ids
=
topk
.
values
.
numpy
(),
topk
.
indices
.
numpy
()
scores
=
softmax
(
outputs
)
elif
function_to_apply
==
ClassificationFunction
.
NONE
:
scores
=
outputs
else
:
else
:
raise
ValueError
(
f
"Unsupported framework:
{
self
.
framework
}
"
)
raise
ValueError
(
f
"Unrecognized `function_to_apply` argument:
{
function_to_apply
}
"
)
dict_scores
=
[
{
"label"
:
self
.
model
.
config
.
id2label
[
i
],
"score"
:
score
.
item
()}
for
i
,
score
in
enumerate
(
scores
)
]
dict_scores
.
sort
(
key
=
lambda
x
:
x
[
"score"
],
reverse
=
True
)
if
top_k
is
not
None
:
dict_scores
=
dict_scores
[:
top_k
]
scores
=
scores
.
tolist
()
return
dict_scores
ids
=
ids
.
tolist
()
return
[{
"score"
:
score
,
"label"
:
self
.
model
.
config
.
id2label
[
_id
]}
for
score
,
_id
in
zip
(
scores
,
ids
)]
tests/pipelines/test_pipelines_image_classification.py
View file @
66964c00
...
@@ -221,3 +221,49 @@ class ImageClassificationPipelineTests(unittest.TestCase):
...
@@ -221,3 +221,49 @@ class ImageClassificationPipelineTests(unittest.TestCase):
{
"score"
:
0.0096
,
"label"
:
"quilt, comforter, comfort, puff"
},
{
"score"
:
0.0096
,
"label"
:
"quilt, comforter, comfort, puff"
},
],
],
)
)
@
slow
@
require_torch
def
test_multilabel_classification
(
self
):
small_model
=
"hf-internal-testing/tiny-random-vit"
# Sigmoid is applied for multi-label classification
image_classifier
=
pipeline
(
"image-classification"
,
model
=
small_model
)
image_classifier
.
model
.
config
.
problem_type
=
"multi_label_classification"
outputs
=
image_classifier
(
"http://images.cocodataset.org/val2017/000000039769.jpg"
)
self
.
assertEqual
(
nested_simplify
(
outputs
,
decimals
=
4
),
[{
"label"
:
"LABEL_1"
,
"score"
:
0.5356
},
{
"label"
:
"LABEL_0"
,
"score"
:
0.4612
}],
)
outputs
=
image_classifier
(
[
"http://images.cocodataset.org/val2017/000000039769.jpg"
,
"http://images.cocodataset.org/val2017/000000039769.jpg"
,
]
)
self
.
assertEqual
(
nested_simplify
(
outputs
,
decimals
=
4
),
[
[{
"label"
:
"LABEL_1"
,
"score"
:
0.5356
},
{
"label"
:
"LABEL_0"
,
"score"
:
0.4612
}],
[{
"label"
:
"LABEL_1"
,
"score"
:
0.5356
},
{
"label"
:
"LABEL_0"
,
"score"
:
0.4612
}],
],
)
@
slow
@
require_torch
def
test_function_to_apply
(
self
):
small_model
=
"hf-internal-testing/tiny-random-vit"
# Sigmoid is applied for multi-label classification
image_classifier
=
pipeline
(
"image-classification"
,
model
=
small_model
)
outputs
=
image_classifier
(
"http://images.cocodataset.org/val2017/000000039769.jpg"
,
function_to_apply
=
"sigmoid"
,
)
self
.
assertEqual
(
nested_simplify
(
outputs
,
decimals
=
4
),
[{
"label"
:
"LABEL_1"
,
"score"
:
0.5356
},
{
"label"
:
"LABEL_0"
,
"score"
:
0.4612
}],
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment