Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f516cf39
Commit
f516cf39
authored
Dec 19, 2019
by
Morgan Funtowicz
Browse files
Allow pipeline to write output in binary format
parent
d72fa2a0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
3 deletions
+39
-3
transformers/commands/run.py
transformers/commands/run.py
+9
-1
transformers/pipelines.py
transformers/pipelines.py
+30
-2
No files found.
transformers/commands/run.py
View file @
f516cf39
import
logging
from
argparse
import
ArgumentParser
from
transformers.commands
import
BaseTransformersCLICommand
from
transformers.pipelines
import
pipeline
,
Pipeline
,
PipelineDataFormat
,
SUPPORTED_TASKS
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
def
try_infer_format_from_ext
(
path
:
str
):
for
ext
in
PipelineDataFormat
.
SUPPORTED_FORMATS
:
if
path
.
endswith
(
ext
):
...
...
@@ -51,6 +55,10 @@ class RunCommand(BaseTransformersCLICommand):
output
+=
[
nlp
(
entry
)]
# Saving data
if
self
.
_nlp
.
binary_output
:
binary_path
=
self
.
_reader
.
save_binary
(
output
)
logger
.
warning
(
'Current pipeline requires output to be in binary format, saving at {}'
.
format
(
binary_path
))
else
:
self
.
_reader
.
save
(
output
)
...
...
transformers/pipelines.py
View file @
f516cf39
...
...
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
csv
import
json
import
os
import
pickle
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
itertools
import
groupby
...
...
@@ -91,6 +92,7 @@ class PipelineDataFormat:
if
exists
(
abspath
(
self
.
output
)):
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
if
path
is
not
None
:
if
not
exists
(
abspath
(
self
.
path
)):
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
...
...
@@ -102,6 +104,15 @@ class PipelineDataFormat:
def
save
(
self
,
data
:
dict
):
raise
NotImplementedError
()
def
save_binary
(
self
,
data
:
Union
[
dict
,
List
[
dict
]])
->
str
:
path
,
_
=
os
.
path
.
splitext
(
self
.
output
)
binary_path
=
os
.
path
.
extsep
.
join
((
path
,
'pickle'
))
with
open
(
binary_path
,
'wb+'
)
as
f_output
:
pickle
.
dump
(
data
,
f_output
)
return
binary_path
@
staticmethod
def
from_str
(
name
:
str
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
if
name
==
'json'
:
...
...
@@ -177,12 +188,20 @@ class PipedPipelineDataFormat(PipelineDataFormat):
# No dictionary to map arguments
else
:
print
(
line
)
yield
line
def
save
(
self
,
data
:
dict
):
print
(
data
)
def
save_binary
(
self
,
data
:
Union
[
dict
,
List
[
dict
]])
->
str
:
if
self
.
output
is
None
:
raise
KeyError
(
'When using piped input on pipeline outputting large object requires an output file path. '
'Please provide such output path through --output argument.'
)
return
super
().
save_binary
(
data
)
class
_ScikitCompat
(
ABC
):
"""
...
...
@@ -205,11 +224,13 @@ class Pipeline(_ScikitCompat):
Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
"""
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
**
kwargs
):
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
binary_output
:
bool
=
False
):
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
device
=
device
self
.
binary_output
=
binary_output
self
.
_args_parser
=
args_parser
or
DefaultArgumentHandler
()
# Special handling
...
...
@@ -325,6 +346,13 @@ class FeatureExtractionPipeline(Pipeline):
"""
Feature extraction pipeline using Model head.
"""
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
):
super
().
__init__
(
model
,
tokenizer
,
args_parser
,
device
,
binary_output
=
True
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
super
().
__call__
(
*
args
,
**
kwargs
).
tolist
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment