Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5a06118b
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e20d8895bdc926babc45e6bfa7ec9047b012aa77"
Unverified
Commit
5a06118b
authored
Jan 06, 2022
by
Nicolas Patry
Committed by
GitHub
Jan 06, 2022
Browse files
Enabling `TF` on `image-classification` pipeline. (#15030)
parent
9f89fa02
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
13 deletions
+70
-13
src/transformers/pipelines/image_classification.py
src/transformers/pipelines/image_classification.py
+28
-9
tests/test_pipelines_image_classification.py
tests/test_pipelines_image_classification.py
+42
-4
No files found.
src/transformers/pipelines/image_classification.py
View file @
5a06118b
from
typing
import
List
,
Union
from
typing
import
List
,
Union
from
..file_utils
import
add_end_docstrings
,
is_torch_available
,
is_vision_available
,
requires_backends
from
..file_utils
import
(
add_end_docstrings
,
is_tf_available
,
is_torch_available
,
is_vision_available
,
requires_backends
,
)
from
..utils
import
logging
from
..utils
import
logging
from
.base
import
PIPELINE_INIT_ARGS
,
Pipeline
from
.base
import
PIPELINE_INIT_ARGS
,
Pipeline
...
@@ -10,6 +16,11 @@ if is_vision_available():
...
@@ -10,6 +16,11 @@ if is_vision_available():
from
..image_utils
import
load_image
from
..image_utils
import
load_image
if
is_tf_available
():
import
tensorflow
as
tf
from
..models.auto.modeling_tf_auto
import
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
if
is_torch_available
():
if
is_torch_available
():
from
..models.auto.modeling_auto
import
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
from
..models.auto.modeling_auto
import
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
...
@@ -31,12 +42,12 @@ class ImageClassificationPipeline(Pipeline):
...
@@ -31,12 +42,12 @@ class ImageClassificationPipeline(Pipeline):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
if
self
.
framework
==
"tf"
:
raise
ValueError
(
f
"The
{
self
.
__class__
}
is only available in PyTorch."
)
requires_backends
(
self
,
"vision"
)
requires_backends
(
self
,
"vision"
)
self
.
check_model_type
(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
)
self
.
check_model_type
(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
if
self
.
framework
==
"tf"
else
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
)
def
_sanitize_parameters
(
self
,
top_k
=
None
):
def
_sanitize_parameters
(
self
,
top_k
=
None
):
postprocess_params
=
{}
postprocess_params
=
{}
...
@@ -77,7 +88,7 @@ class ImageClassificationPipeline(Pipeline):
...
@@ -77,7 +88,7 @@ class ImageClassificationPipeline(Pipeline):
def
preprocess
(
self
,
image
):
def
preprocess
(
self
,
image
):
image
=
load_image
(
image
)
image
=
load_image
(
image
)
model_inputs
=
self
.
feature_extractor
(
images
=
image
,
return_tensors
=
"pt"
)
model_inputs
=
self
.
feature_extractor
(
images
=
image
,
return_tensors
=
self
.
framework
)
return
model_inputs
return
model_inputs
def
_forward
(
self
,
model_inputs
):
def
_forward
(
self
,
model_inputs
):
...
@@ -87,8 +98,16 @@ class ImageClassificationPipeline(Pipeline):
...
@@ -87,8 +98,16 @@ class ImageClassificationPipeline(Pipeline):
def
postprocess
(
self
,
model_outputs
,
top_k
=
5
):
def
postprocess
(
self
,
model_outputs
,
top_k
=
5
):
if
top_k
>
self
.
model
.
config
.
num_labels
:
if
top_k
>
self
.
model
.
config
.
num_labels
:
top_k
=
self
.
model
.
config
.
num_labels
top_k
=
self
.
model
.
config
.
num_labels
probs
=
model_outputs
.
logits
.
softmax
(
-
1
)[
0
]
scores
,
ids
=
probs
.
topk
(
top_k
)
if
self
.
framework
==
"pt"
:
probs
=
model_outputs
.
logits
.
softmax
(
-
1
)[
0
]
scores
,
ids
=
probs
.
topk
(
top_k
)
elif
self
.
framework
==
"tf"
:
probs
=
tf
.
nn
.
softmax
(
model_outputs
.
logits
,
axis
=-
1
)[
0
]
topk
=
tf
.
math
.
top_k
(
probs
,
k
=
top_k
)
scores
,
ids
=
topk
.
values
.
numpy
(),
topk
.
indices
.
numpy
()
else
:
raise
ValueError
(
f
"Unsupported framework:
{
self
.
framework
}
"
)
scores
=
scores
.
tolist
()
scores
=
scores
.
tolist
()
ids
=
ids
.
tolist
()
ids
=
ids
.
tolist
()
...
...
tests/test_pipelines_image_classification.py
View file @
5a06118b
...
@@ -14,7 +14,12 @@
...
@@ -14,7 +14,12 @@
import
unittest
import
unittest
from
transformers
import
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
PreTrainedTokenizer
,
is_vision_available
from
transformers
import
(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
PreTrainedTokenizer
,
is_vision_available
,
)
from
transformers.pipelines
import
ImageClassificationPipeline
,
pipeline
from
transformers.pipelines
import
ImageClassificationPipeline
,
pipeline
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
is_pipeline_test
,
is_pipeline_test
,
...
@@ -40,9 +45,9 @@ else:
...
@@ -40,9 +45,9 @@ else:
@
is_pipeline_test
@
is_pipeline_test
@
require_vision
@
require_vision
@
require_torch
class
ImageClassificationPipelineTests
(
unittest
.
TestCase
,
metaclass
=
PipelineTestCaseMeta
):
class
ImageClassificationPipelineTests
(
unittest
.
TestCase
,
metaclass
=
PipelineTestCaseMeta
):
model_mapping
=
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
model_mapping
=
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
tf_model_mapping
=
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
def
get_test_pipeline
(
self
,
model
,
tokenizer
,
feature_extractor
):
def
get_test_pipeline
(
self
,
model
,
tokenizer
,
feature_extractor
):
image_classifier
=
ImageClassificationPipeline
(
model
=
model
,
feature_extractor
=
feature_extractor
,
top_k
=
2
)
image_classifier
=
ImageClassificationPipeline
(
model
=
model
,
feature_extractor
=
feature_extractor
,
top_k
=
2
)
...
@@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
...
@@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
)
)
@
require_tf
@
require_tf
@
unittest
.
skip
(
"Image classification is not implemented for TF"
)
def
test_small_model_tf
(
self
):
def
test_small_model_tf
(
self
):
pass
small_model
=
"lysandre/tiny-vit-random"
image_classifier
=
pipeline
(
"image-classification"
,
model
=
small_model
)
outputs
=
image_classifier
(
"http://images.cocodataset.org/val2017/000000039769.jpg"
)
self
.
assertEqual
(
nested_simplify
(
outputs
,
decimals
=
4
),
[
{
"score"
:
0.0015
,
"label"
:
"chambered nautilus, pearly nautilus, nautilus"
},
{
"score"
:
0.0015
,
"label"
:
"pajama, pyjama, pj's, jammies"
},
{
"score"
:
0.0014
,
"label"
:
"trench coat"
},
{
"score"
:
0.0014
,
"label"
:
"handkerchief, hankie, hanky, hankey"
},
{
"score"
:
0.0014
,
"label"
:
"baboon"
},
],
)
outputs
=
image_classifier
(
[
"http://images.cocodataset.org/val2017/000000039769.jpg"
,
"http://images.cocodataset.org/val2017/000000039769.jpg"
,
],
top_k
=
2
,
)
self
.
assertEqual
(
nested_simplify
(
outputs
,
decimals
=
4
),
[
[
{
"score"
:
0.0015
,
"label"
:
"chambered nautilus, pearly nautilus, nautilus"
},
{
"score"
:
0.0015
,
"label"
:
"pajama, pyjama, pj's, jammies"
},
],
[
{
"score"
:
0.0015
,
"label"
:
"chambered nautilus, pearly nautilus, nautilus"
},
{
"score"
:
0.0015
,
"label"
:
"pajama, pyjama, pj's, jammies"
},
],
],
)
def
test_custom_tokenizer
(
self
):
def
test_custom_tokenizer
(
self
):
tokenizer
=
PreTrainedTokenizer
()
tokenizer
=
PreTrainedTokenizer
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment