Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9c391277
"...googletest/xcode/Config/ReleaseProject.xcconfig" did not exist on "e3f120b99de7bad9801b51c7e1fffea82d3c4f41"
Commit
9c391277
authored
Dec 16, 2019
by
Morgan Funtowicz
Browse files
Allow tensors placement on specific device through CLI and pipeline.
parent
955d7ecb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
26 deletions
+54
-26
transformers/commands/run.py
transformers/commands/run.py
+2
-1
transformers/pipelines.py
transformers/pipelines.py
+52
-25
No files found.
transformers/commands/run.py
View file @
9c391277
...
@@ -16,7 +16,7 @@ def try_infer_format_from_ext(path: str):
...
@@ -16,7 +16,7 @@ def try_infer_format_from_ext(path: str):
def run_command_factory(args):
    """Build a :class:`RunCommand` from parsed CLI arguments.

    Instantiates the pipeline (placing tensors on the device requested
    via ``--device``: -1 for CPU, >= 0 for a GPU index), infers the input
    data format from the file extension when ``--format infer`` was given,
    and wires up the corresponding data reader.
    """
    nlp = pipeline(task=args.task, model=args.model, tokenizer=args.tokenizer, device=args.device)
    # Renamed from `format`: that name shadowed the `format` builtin.
    data_format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
    reader = PipelineDataFormat.from_str(data_format, args.output, args.input, args.column)
    return RunCommand(nlp, reader)
...
@@ -31,6 +31,7 @@ class RunCommand(BaseTransformersCLICommand):
...
@@ -31,6 +31,7 @@ class RunCommand(BaseTransformersCLICommand):
@
staticmethod
@
staticmethod
def
register_subcommand
(
parser
:
ArgumentParser
):
def
register_subcommand
(
parser
:
ArgumentParser
):
run_parser
=
parser
.
add_parser
(
'run'
,
help
=
"Run a pipeline through the CLI"
)
run_parser
=
parser
.
add_parser
(
'run'
,
help
=
"Run a pipeline through the CLI"
)
run_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU'
)
run_parser
.
add_argument
(
'--task'
,
choices
=
SUPPORTED_TASKS
.
keys
(),
help
=
'Task to run'
)
run_parser
.
add_argument
(
'--task'
,
choices
=
SUPPORTED_TASKS
.
keys
(),
help
=
'Task to run'
)
run_parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'Name or path to the model to instantiate.'
)
run_parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'Name or path to the model to instantiate.'
)
run_parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
help
=
'Name of the tokenizer to use. (default: same as the model name)'
)
run_parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
help
=
'Name of the tokenizer to use. (default: same as the model name)'
)
...
...
transformers/pipelines.py
View file @
9c391277
...
@@ -18,6 +18,7 @@ import csv
...
@@ -18,6 +18,7 @@ import csv
import
json
import
json
import
os
import
os
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
itertools
import
groupby
from
itertools
import
groupby
from
typing
import
Union
,
Optional
,
Tuple
,
List
,
Dict
from
typing
import
Union
,
Optional
,
Tuple
,
List
,
Dict
...
@@ -152,11 +153,18 @@ class JsonPipelineDataFormat(PipelineDataFormat):
...
@@ -152,11 +153,18 @@ class JsonPipelineDataFormat(PipelineDataFormat):
class
Pipeline
(
_ScikitCompat
):
class
Pipeline
(
_ScikitCompat
):
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
args_parser
:
ArgumentHandler
=
None
,
**
kwargs
):
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
**
kwargs
):
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
device
=
device
self
.
_args_parser
=
args_parser
or
DefaultArgumentHandler
()
self
.
_args_parser
=
args_parser
or
DefaultArgumentHandler
()
# Special handling
if
self
.
device
>=
0
and
not
is_tf_available
():
self
.
model
=
self
.
model
.
to
(
'cuda:{}'
.
format
(
self
.
device
))
def
save_pretrained
(
self
,
save_directory
):
def
save_pretrained
(
self
,
save_directory
):
if
not
os
.
path
.
isdir
(
save_directory
):
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
"Provided path ({}) should be a directory"
.
format
(
save_directory
))
logger
.
error
(
"Provided path ({}) should be a directory"
.
format
(
save_directory
))
...
@@ -176,12 +184,26 @@ class Pipeline(_ScikitCompat):
...
@@ -176,12 +184,26 @@ class Pipeline(_ScikitCompat):
inputs
=
self
.
_args_parser
(
*
texts
,
**
kwargs
)
inputs
=
self
.
_args_parser
(
*
texts
,
**
kwargs
)
# Encode for forward
# Encode for forward
with
self
.
device_placement
():
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
inputs
=
self
.
tokenizer
.
batch_encode_plus
(
inputs
,
add_special_tokens
=
True
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
inputs
,
add_special_tokens
=
True
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
)
)
return
self
.
_forward
(
inputs
)
return
self
.
_forward
(
inputs
)
@
contextmanager
def
device_placement
(
self
):
if
is_tf_available
():
import
tensorflow
as
tf
with
tf
.
device
(
'/CPU:0'
if
self
.
device
==
-
1
else
'/device:GPU:{}'
.
format
(
self
.
device
)):
yield
else
:
import
torch
if
self
.
device
>=
0
:
torch
.
cuda
.
set_device
(
self
.
device
)
yield
def
_forward
(
self
,
inputs
):
def
_forward
(
self
,
inputs
):
if
is_tf_available
():
if
is_tf_available
():
# TODO trace model
# TODO trace model
...
@@ -225,6 +247,9 @@ class NerPipeline(Pipeline):
...
@@ -225,6 +247,9 @@ class NerPipeline(Pipeline):
for
i
,
w
in
enumerate
(
words
):
for
i
,
w
in
enumerate
(
words
):
tokens
=
self
.
tokenizer
.
tokenize
(
w
)
tokens
=
self
.
tokenizer
.
tokenize
(
w
)
token_to_word
+=
[
i
]
*
len
(
tokens
)
token_to_word
+=
[
i
]
*
len
(
tokens
)
# Manage correct placement of the tensors
with
self
.
device_placement
():
tokens
=
self
.
tokenizer
.
encode_plus
(
sentence
,
return_attention_mask
=
False
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
)
tokens
=
self
.
tokenizer
.
encode_plus
(
sentence
,
return_attention_mask
=
False
,
return_tensors
=
'tf'
if
is_tf_available
()
else
'pt'
)
# Forward
# Forward
...
@@ -352,6 +377,8 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -352,6 +377,8 @@ class QuestionAnsweringPipeline(Pipeline):
features
=
squad_convert_examples_to_features
(
examples
,
self
.
tokenizer
,
kwargs
[
'max_seq_len'
],
kwargs
[
'doc_stride'
],
kwargs
[
'max_question_len'
],
False
)
features
=
squad_convert_examples_to_features
(
examples
,
self
.
tokenizer
,
kwargs
[
'max_seq_len'
],
kwargs
[
'doc_stride'
],
kwargs
[
'max_question_len'
],
False
)
fw_args
=
self
.
inputs_for_model
(
features
)
fw_args
=
self
.
inputs_for_model
(
features
)
# Manage tensor allocation on correct device
with
self
.
device_placement
():
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
import
tensorflow
as
tf
fw_args
=
{
k
:
tf
.
constant
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
fw_args
=
{
k
:
tf
.
constant
(
v
)
for
(
k
,
v
)
in
fw_args
.
items
()}
...
@@ -374,7 +401,7 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -374,7 +401,7 @@ class QuestionAnsweringPipeline(Pipeline):
# Mask padding and question
# Mask padding and question
start_
,
end_
=
start_
*
np
.
abs
(
np
.
array
(
feature
.
p_mask
)
-
1
),
end_
*
np
.
abs
(
np
.
array
(
feature
.
p_mask
)
-
1
)
start_
,
end_
=
start_
*
np
.
abs
(
np
.
array
(
feature
.
p_mask
)
-
1
),
end_
*
np
.
abs
(
np
.
array
(
feature
.
p_mask
)
-
1
)
# TODO : What happen
d
if not possible
# TODO : What happen
s
if not possible
# Mask CLS
# Mask CLS
start_
[
0
]
=
end_
[
0
]
=
0
start_
[
0
]
=
end_
[
0
]
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment