Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5a06118b
Unverified
Commit
5a06118b
authored
Jan 06, 2022
by
Nicolas Patry
Committed by
GitHub
Jan 06, 2022
Browse files
Enabling `TF` on `image-classification` pipeline. (#15030)
parent
9f89fa02
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
13 deletions
+70
-13
src/transformers/pipelines/image_classification.py
src/transformers/pipelines/image_classification.py
+28
-9
tests/test_pipelines_image_classification.py
tests/test_pipelines_image_classification.py
+42
-4
No files found.
src/transformers/pipelines/image_classification.py
View file @
5a06118b
from
typing
import
List
,
Union
from
..file_utils
import
add_end_docstrings
,
is_torch_available
,
is_vision_available
,
requires_backends
from
..file_utils
import
(
add_end_docstrings
,
is_tf_available
,
is_torch_available
,
is_vision_available
,
requires_backends
,
)
from
..utils
import
logging
from
.base
import
PIPELINE_INIT_ARGS
,
Pipeline
...
...
@@ -10,6 +16,11 @@ if is_vision_available():
from
..image_utils
import
load_image
if
is_tf_available
():
import
tensorflow
as
tf
from
..models.auto.modeling_tf_auto
import
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
if
is_torch_available
():
from
..models.auto.modeling_auto
import
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
...
...
@@ -31,12 +42,12 @@ class ImageClassificationPipeline(Pipeline):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
if
self
.
framework
==
"tf"
:
raise
ValueError
(
f
"The
{
self
.
__class__
}
is only available in PyTorch."
)
requires_backends
(
self
,
"vision"
)
self
.
check_model_type
(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
)
self
.
check_model_type
(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
if
self
.
framework
==
"tf"
else
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
)
def
_sanitize_parameters
(
self
,
top_k
=
None
):
postprocess_params
=
{}
...
...
@@ -77,7 +88,7 @@ class ImageClassificationPipeline(Pipeline):
def
preprocess
(
self
,
image
):
image
=
load_image
(
image
)
model_inputs
=
self
.
feature_extractor
(
images
=
image
,
return_tensors
=
"pt"
)
model_inputs
=
self
.
feature_extractor
(
images
=
image
,
return_tensors
=
self
.
framework
)
return
model_inputs
def
_forward
(
self
,
model_inputs
):
...
...
@@ -87,8 +98,16 @@ class ImageClassificationPipeline(Pipeline):
def
postprocess
(
self
,
model_outputs
,
top_k
=
5
):
if
top_k
>
self
.
model
.
config
.
num_labels
:
top_k
=
self
.
model
.
config
.
num_labels
probs
=
model_outputs
.
logits
.
softmax
(
-
1
)[
0
]
scores
,
ids
=
probs
.
topk
(
top_k
)
if
self
.
framework
==
"pt"
:
probs
=
model_outputs
.
logits
.
softmax
(
-
1
)[
0
]
scores
,
ids
=
probs
.
topk
(
top_k
)
elif
self
.
framework
==
"tf"
:
probs
=
tf
.
nn
.
softmax
(
model_outputs
.
logits
,
axis
=-
1
)[
0
]
topk
=
tf
.
math
.
top_k
(
probs
,
k
=
top_k
)
scores
,
ids
=
topk
.
values
.
numpy
(),
topk
.
indices
.
numpy
()
else
:
raise
ValueError
(
f
"Unsupported framework:
{
self
.
framework
}
"
)
scores
=
scores
.
tolist
()
ids
=
ids
.
tolist
()
...
...
tests/test_pipelines_image_classification.py
View file @
5a06118b
...
...
@@ -14,7 +14,12 @@
import
unittest
from
transformers
import
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
PreTrainedTokenizer
,
is_vision_available
from
transformers
import
(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
PreTrainedTokenizer
,
is_vision_available
,
)
from
transformers.pipelines
import
ImageClassificationPipeline
,
pipeline
from
transformers.testing_utils
import
(
is_pipeline_test
,
...
...
@@ -40,9 +45,9 @@ else:
@
is_pipeline_test
@
require_vision
@
require_torch
class
ImageClassificationPipelineTests
(
unittest
.
TestCase
,
metaclass
=
PipelineTestCaseMeta
):
model_mapping
=
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
tf_model_mapping
=
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
def
get_test_pipeline
(
self
,
model
,
tokenizer
,
feature_extractor
):
image_classifier
=
ImageClassificationPipeline
(
model
=
model
,
feature_extractor
=
feature_extractor
,
top_k
=
2
)
...
...
@@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
)
@
require_tf
@
unittest
.
skip
(
"Image classification is not implemented for TF"
)
def
test_small_model_tf
(
self
):
pass
small_model
=
"lysandre/tiny-vit-random"
image_classifier
=
pipeline
(
"image-classification"
,
model
=
small_model
)
outputs
=
image_classifier
(
"http://images.cocodataset.org/val2017/000000039769.jpg"
)
self
.
assertEqual
(
nested_simplify
(
outputs
,
decimals
=
4
),
[
{
"score"
:
0.0015
,
"label"
:
"chambered nautilus, pearly nautilus, nautilus"
},
{
"score"
:
0.0015
,
"label"
:
"pajama, pyjama, pj's, jammies"
},
{
"score"
:
0.0014
,
"label"
:
"trench coat"
},
{
"score"
:
0.0014
,
"label"
:
"handkerchief, hankie, hanky, hankey"
},
{
"score"
:
0.0014
,
"label"
:
"baboon"
},
],
)
outputs
=
image_classifier
(
[
"http://images.cocodataset.org/val2017/000000039769.jpg"
,
"http://images.cocodataset.org/val2017/000000039769.jpg"
,
],
top_k
=
2
,
)
self
.
assertEqual
(
nested_simplify
(
outputs
,
decimals
=
4
),
[
[
{
"score"
:
0.0015
,
"label"
:
"chambered nautilus, pearly nautilus, nautilus"
},
{
"score"
:
0.0015
,
"label"
:
"pajama, pyjama, pj's, jammies"
},
],
[
{
"score"
:
0.0015
,
"label"
:
"chambered nautilus, pearly nautilus, nautilus"
},
{
"score"
:
0.0015
,
"label"
:
"pajama, pyjama, pj's, jammies"
},
],
],
)
def
test_custom_tokenizer
(
self
):
tokenizer
=
PreTrainedTokenizer
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment