chenpangpang / transformers · Commits

Commit 9c391277, authored Dec 16, 2019 by Morgan Funtowicz
Parent: 955d7ecb

Allow tensors placement on specific device through CLI and pipeline.
Showing 2 changed files with 54 additions and 26 deletions (+54 -26):

  transformers/commands/run.py   +2  -1
  transformers/pipelines.py      +52 -25
transformers/commands/run.py
@@ -16,7 +16,7 @@ def try_infer_format_from_ext(path: str):
 
 def run_command_factory(args):
-    nlp = pipeline(task=args.task, model=args.model, tokenizer=args.tokenizer)
+    nlp = pipeline(task=args.task, model=args.model, tokenizer=args.tokenizer, device=args.device)
     format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
     reader = PipelineDataFormat.from_str(format, args.output, args.input, args.column)
     return RunCommand(nlp, reader)
@@ -31,6 +31,7 @@ class RunCommand(BaseTransformersCLICommand):
     @staticmethod
     def register_subcommand(parser: ArgumentParser):
         run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
+        run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU')
         run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
         run_parser.add_argument('--model', type=str, required=True, help='Name or path to the model to instantiate.')
         run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
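For orientation, a minimal standalone sketch of how the new flag travels from argparse into the pipeline factory. The parser below is hypothetical and only mirrors the flags registered in register_subcommand above (the example values are illustrative); the real command builds the parser through add_parser and then hands the namespace to run_command_factory.

from argparse import ArgumentParser

# Hypothetical standalone parser mirroring the CLI flags registered above.
parser = ArgumentParser(description='Sketch of the run command arguments')
parser.add_argument('--task', type=str, help='Task to run')
parser.add_argument('--model', type=str, required=True, help='Name or path to the model to instantiate.')
parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use.')
parser.add_argument('--device', type=int, default=-1,
                    help='-1 runs on CPU, >= 0 selects the corresponding GPU')

args = parser.parse_args(['--task', 'ner', '--model', 'bert-base-cased', '--device', '0'])

# run_command_factory then simply forwards the parsed value:
# nlp = pipeline(task=args.task, model=args.model, tokenizer=args.tokenizer, device=args.device)
print(args.task, args.model, args.device)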
transformers/pipelines.py
@@ -18,6 +18,7 @@ import csv
 import json
 import os
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from itertools import groupby
 from typing import Union, Optional, Tuple, List, Dict
@@ -152,11 +153,18 @@ class JsonPipelineDataFormat(PipelineDataFormat):
 
 class Pipeline(_ScikitCompat):
 
-    def __init__(self, model, tokenizer: PreTrainedTokenizer = None, args_parser: ArgumentHandler = None, **kwargs):
+    def __init__(self, model, tokenizer: PreTrainedTokenizer = None, args_parser: ArgumentHandler = None, device: int = -1, **kwargs):
         self.model = model
         self.tokenizer = tokenizer
+        self.device = device
         self._args_parser = args_parser or DefaultArgumentHandler()
 
+        # Special handling
+        if self.device >= 0 and not is_tf_available():
+            self.model = self.model.to('cuda:{}'.format(self.device))
+
     def save_pretrained(self, save_directory):
         if not os.path.isdir(save_directory):
             logger.error("Provided path ({}) should be a directory".format(save_directory))
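A short usage sketch of the new keyword from user code (the checkpoint name is only illustrative): with PyTorch installed and a CUDA device present, the constructor above moves the model to cuda:0, while the default device=-1 keeps everything on CPU.

from transformers import pipeline

# device=-1 (default) stays on CPU; device=0 targets the first GPU.
nlp = pipeline(task='sentiment-analysis',
               model='distilbert-base-uncased-finetuned-sst-2-english',  # illustrative checkpoint
               device=0)

print(nlp('Tensors now follow the requested device.'))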
@@ -176,11 +184,25 @@ class Pipeline(_ScikitCompat):
         inputs = self._args_parser(*texts, **kwargs)
 
         # Encode for forward
-        inputs = self.tokenizer.batch_encode_plus(
-            inputs, add_special_tokens=True,
-            return_tensors='tf' if is_tf_available() else 'pt'
-        )
-        return self._forward(inputs)
+        with self.device_placement():
+            inputs = self.tokenizer.batch_encode_plus(
+                inputs, add_special_tokens=True,
+                return_tensors='tf' if is_tf_available() else 'pt'
+            )
+
+            return self._forward(inputs)
+
+    @contextmanager
+    def device_placement(self):
+        if is_tf_available():
+            import tensorflow as tf
+            with tf.device('/CPU:0' if self.device == -1 else '/device:GPU:{}'.format(self.device)):
+                yield
+        else:
+            import torch
+            if self.device >= 0:
+                torch.cuda.set_device(self.device)
+
+            yield
 
     def _forward(self, inputs):
         if is_tf_available():
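For readers less familiar with the two framework idioms used by device_placement: TensorFlow scopes placement with a device string, while PyTorch sets a process-wide default CUDA device. A minimal standalone sketch of the TensorFlow side, assuming TensorFlow is installed (the index value is illustrative):

import tensorflow as tf

device = -1  # same convention as the pipeline: -1 -> CPU, >= 0 -> GPU index
with tf.device('/CPU:0' if device == -1 else '/device:GPU:{}'.format(device)):
    x = tf.constant([1.0, 2.0])   # created inside the scope, so placed accordingly

print(x.device)  # reports the device the tensor actually landed on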
@@ -225,14 +247,17 @@ class NerPipeline(Pipeline):
             for i, w in enumerate(words):
                 tokens = self.tokenizer.tokenize(w)
                 token_to_word += [i] * len(tokens)
 
-            tokens = self.tokenizer.encode_plus(sentence, return_attention_mask=False, return_tensors='tf' if is_tf_available() else 'pt')
-
-            # Forward
-            if is_torch_available():
-                with torch.no_grad():
-                    entities = self.model(**tokens)[0][0].cpu().numpy()
-            else:
-                entities = self.model(tokens)[0][0].numpy()
+            # Manage correct placement of the tensors
+            with self.device_placement():
+                tokens = self.tokenizer.encode_plus(sentence, return_attention_mask=False, return_tensors='tf' if is_tf_available() else 'pt')
+
+                # Forward
+                if is_torch_available():
+                    with torch.no_grad():
+                        entities = self.model(**tokens)[0][0].cpu().numpy()
+                else:
+                    entities = self.model(tokens)[0][0].numpy()
 
             # Normalize scores
             answer, token_start = [], 1
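The .cpu() call before .numpy() in the torch branch is needed because NumPy can only view host memory, so a CUDA tensor has to be copied back first. A minimal illustration, which only exercises the GPU path when CUDA is actually available:

import torch

t = torch.zeros(2, 3)
if torch.cuda.is_available():
    t = t.to('cuda:0')
    # t.numpy() would raise a TypeError here; move to host memory first.
    arr = t.cpu().numpy()
else:
    arr = t.numpy()

print(arr.shape)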
@@ -352,18 +377,20 @@ class QuestionAnsweringPipeline(Pipeline):
         features = squad_convert_examples_to_features(examples, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False)
         fw_args = self.inputs_for_model(features)
 
-        if is_tf_available():
-            import tensorflow as tf
-            fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
-            start, end = self.model(fw_args)
-            start, end = start.numpy(), end.numpy()
-        else:
-            import torch
-            with torch.no_grad():
-                # Retrieve the score for the context tokens only (removing question tokens)
-                fw_args = {k: torch.tensor(v) for (k, v) in fw_args.items()}
-                start, end = self.model(**fw_args)
-                start, end = start.cpu().numpy(), end.cpu().numpy()
+        # Manage tensor allocation on correct device
+        with self.device_placement():
+            if is_tf_available():
+                import tensorflow as tf
+                fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()}
+                start, end = self.model(fw_args)
+                start, end = start.numpy(), end.numpy()
+            else:
+                import torch
+                with torch.no_grad():
+                    # Retrieve the score for the context tokens only (removing question tokens)
+                    fw_args = {k: torch.tensor(v) for (k, v) in fw_args.items()}
+                    start, end = self.model(**fw_args)
+                    start, end = start.cpu().numpy(), end.cpu().numpy()
 
         answers = []
         for (example, feature, start_, end_) in zip(examples, features, start, end):
@@ -374,7 +401,7 @@ class QuestionAnsweringPipeline(Pipeline):
             # Mask padding and question
             start_, end_ = start_ * np.abs(np.array(feature.p_mask) - 1), end_ * np.abs(np.array(feature.p_mask) - 1)
 
-            # TODO : What happend if not possible
+            # TODO : What happens if not possible
             # Mask CLS
             start_[0] = end_[0] = 0
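To make the p_mask line above concrete: p_mask marks question and padding positions with 1 and context positions with 0, so np.abs(p_mask - 1) flips it into a keep-mask that zeroes out non-context scores. A tiny worked example with made-up values:

import numpy as np

# Hypothetical start scores for 6 tokens and a p_mask flagging question/padding tokens.
start_ = np.array([0.10, 0.70, 0.05, 0.90, 0.20, 0.30])
p_mask = [1, 1, 0, 0, 0, 1]

keep = np.abs(np.array(p_mask) - 1)   # -> array([0, 0, 1, 1, 1, 0])
print(start_ * keep)                  # question/padding scores are zeroed; context scores survive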