Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b56848c8
"vscode:/vscode.git/clone" did not exist on "aee11fe427b2f2fd66c3ef3cd91757ec00420ac9"
Unverified
Commit
b56848c8
authored
Jun 17, 2021
by
Lysandre Debut
Committed by
GitHub
Jun 17, 2021
Browse files
Pipeline update & tests (#12207)
parent
700cee34
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
1 deletion
+36
-1
src/transformers/pipelines/image_classification.py
src/transformers/pipelines/image_classification.py
+5
-1
tests/test_pipelines_image_classification.py
tests/test_pipelines_image_classification.py
+31
-0
No files found.
src/transformers/pipelines/image_classification.py
View file @
b56848c8
...
...
@@ -87,7 +87,8 @@ class ImageClassificationPipeline(Pipeline):
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
images.
top_k (:obj:`int`, `optional`, defaults to 5):
The number of top labels that will be returned by the pipeline.
The number of top labels that will be returned by the pipeline. If the provided number is higher than
the number of labels available in the model configuration, it will default to the number of labels.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
...
...
@@ -106,6 +107,9 @@ class ImageClassificationPipeline(Pipeline):
images
=
[
self
.
load_image
(
image
)
for
image
in
images
]
if
top_k
>
self
.
model
.
config
.
num_labels
:
top_k
=
self
.
model
.
config
.
num_labels
with
torch
.
no_grad
():
inputs
=
self
.
feature_extractor
(
images
=
images
,
return_tensors
=
"pt"
)
outputs
=
self
.
model
(
**
inputs
)
...
...
tests/test_pipelines_image_classification.py
View file @
b56848c8
...
...
@@ -15,6 +15,7 @@
import
unittest
from
transformers
import
(
AutoConfig
,
AutoFeatureExtractor
,
AutoModelForImageClassification
,
PreTrainedTokenizer
,
...
...
@@ -128,3 +129,33 @@ class ImageClassificationPipelineTests(unittest.TestCase):
image_classifier
=
pipeline
(
"image-classification"
,
model
=
self
.
small_models
[
0
],
tokenizer
=
tokenizer
)
self
.
assertIs
(
image_classifier
.
tokenizer
,
tokenizer
)
def
test_num_labels_inferior_to_topk
(
self
):
for
small_model
in
self
.
small_models
:
num_labels
=
2
model
=
AutoModelForImageClassification
.
from_config
(
AutoConfig
.
from_pretrained
(
small_model
,
num_labels
=
num_labels
)
)
feature_extractor
=
AutoFeatureExtractor
.
from_pretrained
(
small_model
)
image_classifier
=
ImageClassificationPipeline
(
model
=
model
,
feature_extractor
=
feature_extractor
)
for
valid_input
in
self
.
valid_inputs
:
output
=
image_classifier
(
**
valid_input
)
def
assert_valid_pipeline_output
(
pipeline_output
):
self
.
assertTrue
(
isinstance
(
pipeline_output
,
list
))
self
.
assertEqual
(
len
(
pipeline_output
),
num_labels
)
for
label_result
in
pipeline_output
:
self
.
assertTrue
(
isinstance
(
label_result
,
dict
))
self
.
assertIn
(
"label"
,
label_result
)
self
.
assertIn
(
"score"
,
label_result
)
if
isinstance
(
valid_input
[
"images"
],
list
):
# When images are batched, pipeline output is a list of lists of dictionaries
self
.
assertEqual
(
len
(
valid_input
[
"images"
]),
len
(
output
))
for
individual_output
in
output
:
assert_valid_pipeline_output
(
individual_output
)
else
:
# When images are batched, pipeline output is a list of dictionaries
assert_valid_pipeline_output
(
output
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment