Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2cb2ea3f
Unverified
Commit
2cb2ea3f
authored
May 18, 2022
by
Nicolas Patry
Committed by
GitHub
May 18, 2022
Browse files
Accepting real pytorch device as arguments. (#17318)
* Accepting real pytorch device as arguments. * is_torch_available.
parent
1c9d1f4c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
2 deletions
+19
-2
src/transformers/pipelines/base.py
src/transformers/pipelines/base.py
+5
-2
tests/pipelines/test_pipelines_text_classification.py
tests/pipelines/test_pipelines_text_classification.py
+14
-0
No files found.
src/transformers/pipelines/base.py
View file @
2cb2ea3f
...
@@ -693,7 +693,7 @@ PIPELINE_INIT_ARGS = r"""
...
@@ -693,7 +693,7 @@ PIPELINE_INIT_ARGS = r"""
Reference to the object in charge of parsing supplied pipeline parameters.
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
the associated CUDA device id.
the associated CUDA device id.
You can pass native `torch.device` too.
binary_output (`bool`, *optional*, defaults to `False`):
binary_output (`bool`, *optional*, defaults to `False`):
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
"""
"""
...
@@ -750,7 +750,10 @@ class Pipeline(_ScikitCompat):
...
@@ -750,7 +750,10 @@ class Pipeline(_ScikitCompat):
self
.
feature_extractor
=
feature_extractor
self
.
feature_extractor
=
feature_extractor
self
.
modelcard
=
modelcard
self
.
modelcard
=
modelcard
self
.
framework
=
framework
self
.
framework
=
framework
self
.
device
=
device
if
framework
==
"tf"
else
torch
.
device
(
"cpu"
if
device
<
0
else
f
"cuda:
{
device
}
"
)
if
is_torch_available
()
and
isinstance
(
device
,
torch
.
device
):
self
.
device
=
device
else
:
self
.
device
=
device
if
framework
==
"tf"
else
torch
.
device
(
"cpu"
if
device
<
0
else
f
"cuda:
{
device
}
"
)
self
.
binary_output
=
binary_output
self
.
binary_output
=
binary_output
# Special handling
# Special handling
...
...
tests/pipelines/test_pipelines_text_classification.py
View file @
2cb2ea3f
...
@@ -39,6 +39,20 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
...
@@ -39,6 +39,20 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
outputs
=
text_classifier
(
"This is great !"
)
outputs
=
text_classifier
(
"This is great !"
)
self
.
assertEqual
(
nested_simplify
(
outputs
),
[{
"label"
:
"LABEL_0"
,
"score"
:
0.504
}])
self
.
assertEqual
(
nested_simplify
(
outputs
),
[{
"label"
:
"LABEL_0"
,
"score"
:
0.504
}])
@
require_torch
def
test_accepts_torch_device
(
self
):
import
torch
text_classifier
=
pipeline
(
task
=
"text-classification"
,
model
=
"hf-internal-testing/tiny-random-distilbert"
,
framework
=
"pt"
,
device
=
torch
.
device
(
"cpu"
),
)
outputs
=
text_classifier
(
"This is great !"
)
self
.
assertEqual
(
nested_simplify
(
outputs
),
[{
"label"
:
"LABEL_0"
,
"score"
:
0.504
}])
@
require_tf
@
require_tf
def
test_small_model_tf
(
self
):
def
test_small_model_tf
(
self
):
text_classifier
=
pipeline
(
text_classifier
=
pipeline
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment