Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bbaaec04
Commit
bbaaec04
authored
Dec 20, 2019
by
thomwolf
Browse files
fixing CLI pipeline
parent
1c12ee0e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
41 deletions
+58
-41
transformers/commands/run.py
transformers/commands/run.py
+22
-10
transformers/pipelines.py
transformers/pipelines.py
+36
-31
No files found.
transformers/commands/run.py
View file @
bbaaec04
...
@@ -9,6 +9,9 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
...
@@ -9,6 +9,9 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def
try_infer_format_from_ext
(
path
:
str
):
def
try_infer_format_from_ext
(
path
:
str
):
if
not
path
:
return
'pipe'
for
ext
in
PipelineDataFormat
.
SUPPORTED_FORMATS
:
for
ext
in
PipelineDataFormat
.
SUPPORTED_FORMATS
:
if
path
.
endswith
(
ext
):
if
path
.
endswith
(
ext
):
return
ext
return
ext
...
@@ -20,9 +23,16 @@ def try_infer_format_from_ext(path: str):
...
@@ -20,9 +23,16 @@ def try_infer_format_from_ext(path: str):
def
run_command_factory
(
args
):
def
run_command_factory
(
args
):
nlp
=
pipeline
(
task
=
args
.
task
,
model
=
args
.
model
,
config
=
args
.
config
,
tokenizer
=
args
.
tokenizer
,
device
=
args
.
device
)
nlp
=
pipeline
(
task
=
args
.
task
,
model
=
args
.
model
if
args
.
model
else
None
,
config
=
args
.
config
,
tokenizer
=
args
.
tokenizer
,
device
=
args
.
device
)
format
=
try_infer_format_from_ext
(
args
.
input
)
if
args
.
format
==
'infer'
else
args
.
format
format
=
try_infer_format_from_ext
(
args
.
input
)
if
args
.
format
==
'infer'
else
args
.
format
reader
=
PipelineDataFormat
.
from_str
(
format
,
args
.
output
,
args
.
input
,
args
.
column
)
reader
=
PipelineDataFormat
.
from_str
(
format
=
format
,
output_path
=
args
.
output
,
input_path
=
args
.
input
,
column
=
args
.
column
if
args
.
column
else
nlp
.
default_input_names
)
return
RunCommand
(
nlp
,
reader
)
return
RunCommand
(
nlp
,
reader
)
...
@@ -35,24 +45,26 @@ class RunCommand(BaseTransformersCLICommand):
...
@@ -35,24 +45,26 @@ class RunCommand(BaseTransformersCLICommand):
@
staticmethod
@
staticmethod
def
register_subcommand
(
parser
:
ArgumentParser
):
def
register_subcommand
(
parser
:
ArgumentParser
):
run_parser
=
parser
.
add_parser
(
'run'
,
help
=
"Run a pipeline through the CLI"
)
run_parser
=
parser
.
add_parser
(
'run'
,
help
=
"Run a pipeline through the CLI"
)
run_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)'
)
run_parser
.
add_argument
(
'--task'
,
choices
=
SUPPORTED_TASKS
.
keys
(),
help
=
'Task to run'
)
run_parser
.
add_argument
(
'--task'
,
choices
=
SUPPORTED_TASKS
.
keys
(),
help
=
'Task to run'
)
run_parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'Name or path to the model to instantiate.'
)
run_parser
.
add_argument
(
'--input'
,
type
=
str
,
help
=
'Path to the file to use for inference'
)
run_parser
.
add_argument
(
'--output'
,
type
=
str
,
help
=
'Path to the file that will be used post to write results.'
)
run_parser
.
add_argument
(
'--model'
,
type
=
str
,
help
=
'Name or path to the model to instantiate.'
)
run_parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
'Name or path to the model
\'
s config to instantiate.'
)
run_parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
'Name or path to the model
\'
s config to instantiate.'
)
run_parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
help
=
'Name of the tokenizer to use. (default: same as the model name)'
)
run_parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
help
=
'Name of the tokenizer to use. (default: same as the model name)'
)
run_parser
.
add_argument
(
'--column'
,
type
=
str
,
help
=
'Name of the column to use as input. (For multi columns input as QA use column1,columns2)'
)
run_parser
.
add_argument
(
'--column'
,
type
=
str
,
help
=
'Name of the column to use as input. (For multi columns input as QA use column1,columns2)'
)
run_parser
.
add_argument
(
'--format'
,
type
=
str
,
default
=
'infer'
,
choices
=
PipelineDataFormat
.
SUPPORTED_FORMATS
,
help
=
'Input format to read from'
)
run_parser
.
add_argument
(
'--format'
,
type
=
str
,
default
=
'infer'
,
choices
=
PipelineDataFormat
.
SUPPORTED_FORMATS
,
help
=
'Input format to read from'
)
run_parser
.
add_argument
(
'--input'
,
type
=
str
,
help
=
'Path to the file to use for inference'
)
run_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)'
)
run_parser
.
add_argument
(
'--output'
,
type
=
str
,
help
=
'Path to the file that will be used post to write results.'
)
run_parser
.
set_defaults
(
func
=
run_command_factory
)
run_parser
.
set_defaults
(
func
=
run_command_factory
)
def
run
(
self
):
def
run
(
self
):
nlp
,
output
=
self
.
_nlp
,
[]
nlp
,
outputs
=
self
.
_nlp
,
[]
for
entry
in
self
.
_reader
:
for
entry
in
self
.
_reader
:
if
self
.
_reader
.
is_multi_columns
:
output
=
nlp
(
**
entry
)
if
self
.
_reader
.
is_multi_columns
else
nlp
(
entry
)
output
+=
nlp
(
**
entry
)
if
isinstance
(
output
,
dict
):
outputs
.
append
(
output
)
else
:
else
:
output
+=
nlp
(
entry
)
output
s
+=
output
# Saving data
# Saving data
if
self
.
_nlp
.
binary_output
:
if
self
.
_nlp
.
binary_output
:
...
...
transformers/pipelines.py
View file @
bbaaec04
...
@@ -14,12 +14,14 @@
...
@@ -14,12 +14,14 @@
# limitations under the License.
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
sys
import
csv
import
csv
import
json
import
json
import
os
import
os
import
pickle
import
pickle
import
logging
import
logging
import
six
import
six
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
itertools
import
groupby
from
itertools
import
groupby
...
@@ -98,28 +100,29 @@ class PipelineDataFormat:
...
@@ -98,28 +100,29 @@ class PipelineDataFormat:
Supported data formats currently includes:
Supported data formats currently includes:
- JSON
- JSON
- CSV
- CSV
- stdin/stdout (pipe)
PipelineDataFormat also includes some utilities to work with multi-columns like mapping from datasets columns
PipelineDataFormat also includes some utilities to work with multi-columns like mapping from datasets columns
to pipelines keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
to pipelines keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
"""
"""
SUPPORTED_FORMATS
=
[
'json'
,
'csv'
,
'pipe'
]
SUPPORTED_FORMATS
=
[
'json'
,
'csv'
,
'pipe'
]
def
__init__
(
self
,
output
:
Optional
[
str
],
input
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output
_path
:
Optional
[
str
],
input
_path
:
Optional
[
str
],
column
:
Optional
[
str
]):
self
.
output
=
output
self
.
output
_path
=
output
_path
self
.
path
=
input
self
.
input_
path
=
input
_path
self
.
column
=
column
.
split
(
','
)
if
column
else
[
''
]
self
.
column
=
column
.
split
(
','
)
if
column
is
not
None
else
[
''
]
self
.
is_multi_columns
=
len
(
self
.
column
)
>
1
self
.
is_multi_columns
=
len
(
self
.
column
)
>
1
if
self
.
is_multi_columns
:
if
self
.
is_multi_columns
:
self
.
column
=
[
tuple
(
c
.
split
(
'='
))
if
'='
in
c
else
(
c
,
c
)
for
c
in
self
.
column
]
self
.
column
=
[
tuple
(
c
.
split
(
'='
))
if
'='
in
c
else
(
c
,
c
)
for
c
in
self
.
column
]
if
output
is
not
None
:
if
output
_path
is
not
None
:
if
exists
(
abspath
(
self
.
output
)):
if
exists
(
abspath
(
self
.
output
_path
)):
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
_path
))
if
input
is
not
None
:
if
input
_path
is
not
None
:
if
not
exists
(
abspath
(
self
.
path
)):
if
not
exists
(
abspath
(
self
.
input_
path
)):
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
input_
path
))
@
abstractmethod
@
abstractmethod
def
__iter__
(
self
):
def
__iter__
(
self
):
...
@@ -140,7 +143,7 @@ class PipelineDataFormat:
...
@@ -140,7 +143,7 @@ class PipelineDataFormat:
:param data: data to store
:param data: data to store
:return: (str) Path where the data has been saved
:return: (str) Path where the data has been saved
"""
"""
path
,
_
=
os
.
path
.
splitext
(
self
.
output
)
path
,
_
=
os
.
path
.
splitext
(
self
.
output
_path
)
binary_path
=
os
.
path
.
extsep
.
join
((
path
,
'pickle'
))
binary_path
=
os
.
path
.
extsep
.
join
((
path
,
'pickle'
))
with
open
(
binary_path
,
'wb+'
)
as
f_output
:
with
open
(
binary_path
,
'wb+'
)
as
f_output
:
...
@@ -149,23 +152,23 @@ class PipelineDataFormat:
...
@@ -149,23 +152,23 @@ class PipelineDataFormat:
return
binary_path
return
binary_path
@
staticmethod
@
staticmethod
def
from_str
(
name
:
str
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
from_str
(
format
:
str
,
output
_path
:
Optional
[
str
],
input_
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
if
name
==
'json'
:
if
format
==
'json'
:
return
JsonPipelineDataFormat
(
output
,
path
,
column
)
return
JsonPipelineDataFormat
(
output
_path
,
input_
path
,
column
)
elif
name
==
'csv'
:
elif
format
==
'csv'
:
return
CsvPipelineDataFormat
(
output
,
path
,
column
)
return
CsvPipelineDataFormat
(
output
_path
,
input_
path
,
column
)
elif
name
==
'pipe'
:
elif
format
==
'pipe'
:
return
PipedPipelineDataFormat
(
output
,
path
,
column
)
return
PipedPipelineDataFormat
(
output
_path
,
input_
path
,
column
)
else
:
else
:
raise
KeyError
(
'Unknown reader {} (Available reader are json/csv/pipe)'
.
format
(
name
))
raise
KeyError
(
'Unknown reader {} (Available reader are json/csv/pipe)'
.
format
(
format
))
class
CsvPipelineDataFormat
(
PipelineDataFormat
):
class
CsvPipelineDataFormat
(
PipelineDataFormat
):
def
__init__
(
self
,
output
:
Optional
[
str
],
input
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output
_path
:
Optional
[
str
],
input
_path
:
Optional
[
str
],
column
:
Optional
[
str
]):
super
().
__init__
(
output
,
input
,
column
)
super
().
__init__
(
output
_path
,
input
_path
,
column
)
def
__iter__
(
self
):
def
__iter__
(
self
):
with
open
(
self
.
path
,
'r'
)
as
f
:
with
open
(
self
.
input_
path
,
'r'
)
as
f
:
reader
=
csv
.
DictReader
(
f
)
reader
=
csv
.
DictReader
(
f
)
for
row
in
reader
:
for
row
in
reader
:
if
self
.
is_multi_columns
:
if
self
.
is_multi_columns
:
...
@@ -174,7 +177,7 @@ class CsvPipelineDataFormat(PipelineDataFormat):
...
@@ -174,7 +177,7 @@ class CsvPipelineDataFormat(PipelineDataFormat):
yield
row
[
self
.
column
[
0
]]
yield
row
[
self
.
column
[
0
]]
def
save
(
self
,
data
:
List
[
dict
]):
def
save
(
self
,
data
:
List
[
dict
]):
with
open
(
self
.
output
,
'w'
)
as
f
:
with
open
(
self
.
output
_path
,
'w'
)
as
f
:
if
len
(
data
)
>
0
:
if
len
(
data
)
>
0
:
writer
=
csv
.
DictWriter
(
f
,
list
(
data
[
0
].
keys
()))
writer
=
csv
.
DictWriter
(
f
,
list
(
data
[
0
].
keys
()))
writer
.
writeheader
()
writer
.
writeheader
()
...
@@ -182,10 +185,10 @@ class CsvPipelineDataFormat(PipelineDataFormat):
...
@@ -182,10 +185,10 @@ class CsvPipelineDataFormat(PipelineDataFormat):
class
JsonPipelineDataFormat
(
PipelineDataFormat
):
class
JsonPipelineDataFormat
(
PipelineDataFormat
):
def
__init__
(
self
,
output
:
Optional
[
str
],
input
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output
_path
:
Optional
[
str
],
input
_path
:
Optional
[
str
],
column
:
Optional
[
str
]):
super
().
__init__
(
output
,
input
,
column
)
super
().
__init__
(
output
_path
,
input
_path
,
column
)
with
open
(
input
,
'r'
)
as
f
:
with
open
(
input
_path
,
'r'
)
as
f
:
self
.
_entries
=
json
.
load
(
f
)
self
.
_entries
=
json
.
load
(
f
)
def
__iter__
(
self
):
def
__iter__
(
self
):
...
@@ -196,7 +199,7 @@ class JsonPipelineDataFormat(PipelineDataFormat):
...
@@ -196,7 +199,7 @@ class JsonPipelineDataFormat(PipelineDataFormat):
yield
entry
[
self
.
column
[
0
]]
yield
entry
[
self
.
column
[
0
]]
def
save
(
self
,
data
:
dict
):
def
save
(
self
,
data
:
dict
):
with
open
(
self
.
output
,
'w'
)
as
f
:
with
open
(
self
.
output
_path
,
'w'
)
as
f
:
json
.
dump
(
data
,
f
)
json
.
dump
(
data
,
f
)
...
@@ -208,9 +211,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
...
@@ -208,9 +211,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
If columns are provided, then the output will be a dictionary with {column_x: value_x}
If columns are provided, then the output will be a dictionary with {column_x: value_x}
"""
"""
def
__iter__
(
self
):
def
__iter__
(
self
):
import
sys
for
line
in
sys
.
stdin
:
for
line
in
sys
.
stdin
:
# Split for multi-columns
# Split for multi-columns
if
'
\t
'
in
line
:
if
'
\t
'
in
line
:
...
@@ -229,7 +230,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
...
@@ -229,7 +230,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
print
(
data
)
print
(
data
)
def
save_binary
(
self
,
data
:
Union
[
dict
,
List
[
dict
]])
->
str
:
def
save_binary
(
self
,
data
:
Union
[
dict
,
List
[
dict
]])
->
str
:
if
self
.
output
is
None
:
if
self
.
output
_path
is
None
:
raise
KeyError
(
raise
KeyError
(
'When using piped input on pipeline outputting large object requires an output file path. '
'When using piped input on pipeline outputting large object requires an output file path. '
'Please provide such output path through --output argument.'
'Please provide such output path through --output argument.'
...
@@ -294,6 +295,9 @@ class Pipeline(_ScikitCompat):
...
@@ -294,6 +295,9 @@ class Pipeline(_ScikitCompat):
nlp = NerPipeline(model='...', config='...', tokenizer='...')
nlp = NerPipeline(model='...', config='...', tokenizer='...')
nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
"""
"""
default_input_names
=
None
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
modelcard
:
ModelCard
=
None
,
framework
:
Optional
[
str
]
=
None
,
modelcard
:
ModelCard
=
None
,
framework
:
Optional
[
str
]
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
...
@@ -582,6 +586,8 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -582,6 +586,8 @@ class QuestionAnsweringPipeline(Pipeline):
Question Answering pipeline using ModelForQuestionAnswering head.
Question Answering pipeline using ModelForQuestionAnswering head.
"""
"""
default_input_names
=
'question,context'
def
__init__
(
self
,
model
,
def
__init__
(
self
,
model
,
tokenizer
:
Optional
[
PreTrainedTokenizer
],
tokenizer
:
Optional
[
PreTrainedTokenizer
],
modelcard
:
Optional
[
ModelCard
],
modelcard
:
Optional
[
ModelCard
],
...
@@ -684,7 +690,6 @@ class QuestionAnsweringPipeline(Pipeline):
...
@@ -684,7 +690,6 @@ class QuestionAnsweringPipeline(Pipeline):
}
}
for
s
,
e
,
score
in
zip
(
starts
,
ends
,
scores
)
for
s
,
e
,
score
in
zip
(
starts
,
ends
,
scores
)
]
]
if
len
(
answers
)
==
1
:
if
len
(
answers
)
==
1
:
return
answers
[
0
]
return
answers
[
0
]
return
answers
return
answers
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment