Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f516cf39
Commit
f516cf39
authored
Dec 19, 2019
by
Morgan Funtowicz
Browse files
Allow pipeline to write output in binary format
parent
d72fa2a0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
3 deletions
+39
-3
transformers/commands/run.py
transformers/commands/run.py
+9
-1
transformers/pipelines.py
transformers/pipelines.py
+30
-2
No files found.
transformers/commands/run.py
View file @
f516cf39
import
logging
from
argparse
import
ArgumentParser
from
argparse
import
ArgumentParser
from
transformers.commands
import
BaseTransformersCLICommand
from
transformers.commands
import
BaseTransformersCLICommand
from
transformers.pipelines
import
pipeline
,
Pipeline
,
PipelineDataFormat
,
SUPPORTED_TASKS
from
transformers.pipelines
import
pipeline
,
Pipeline
,
PipelineDataFormat
,
SUPPORTED_TASKS
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
def
try_infer_format_from_ext
(
path
:
str
):
def
try_infer_format_from_ext
(
path
:
str
):
for
ext
in
PipelineDataFormat
.
SUPPORTED_FORMATS
:
for
ext
in
PipelineDataFormat
.
SUPPORTED_FORMATS
:
if
path
.
endswith
(
ext
):
if
path
.
endswith
(
ext
):
...
@@ -51,6 +55,10 @@ class RunCommand(BaseTransformersCLICommand):
...
@@ -51,6 +55,10 @@ class RunCommand(BaseTransformersCLICommand):
output
+=
[
nlp
(
entry
)]
output
+=
[
nlp
(
entry
)]
# Saving data
# Saving data
if
self
.
_nlp
.
binary_output
:
binary_path
=
self
.
_reader
.
save_binary
(
output
)
logger
.
warning
(
'Current pipeline requires output to be in binary format, saving at {}'
.
format
(
binary_path
))
else
:
self
.
_reader
.
save
(
output
)
self
.
_reader
.
save
(
output
)
...
...
transformers/pipelines.py
View file @
f516cf39
...
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
...
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
csv
import
csv
import
json
import
json
import
os
import
os
import
pickle
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
itertools
import
groupby
from
itertools
import
groupby
...
@@ -91,6 +92,7 @@ class PipelineDataFormat:
...
@@ -91,6 +92,7 @@ class PipelineDataFormat:
if
exists
(
abspath
(
self
.
output
)):
if
exists
(
abspath
(
self
.
output
)):
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
if
path
is
not
None
:
if
not
exists
(
abspath
(
self
.
path
)):
if
not
exists
(
abspath
(
self
.
path
)):
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
...
@@ -102,6 +104,15 @@ class PipelineDataFormat:
...
@@ -102,6 +104,15 @@ class PipelineDataFormat:
def
save
(
self
,
data
:
dict
):
def
save
(
self
,
data
:
dict
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
save_binary
(
self
,
data
:
Union
[
dict
,
List
[
dict
]])
->
str
:
path
,
_
=
os
.
path
.
splitext
(
self
.
output
)
binary_path
=
os
.
path
.
extsep
.
join
((
path
,
'pickle'
))
with
open
(
binary_path
,
'wb+'
)
as
f_output
:
pickle
.
dump
(
data
,
f_output
)
return
binary_path
@
staticmethod
@
staticmethod
def
from_str
(
name
:
str
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
from_str
(
name
:
str
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
if
name
==
'json'
:
if
name
==
'json'
:
...
@@ -177,12 +188,20 @@ class PipedPipelineDataFormat(PipelineDataFormat):
...
@@ -177,12 +188,20 @@ class PipedPipelineDataFormat(PipelineDataFormat):
# No dictionary to map arguments
# No dictionary to map arguments
else
:
else
:
print
(
line
)
yield
line
yield
line
def
save
(
self
,
data
:
dict
):
def
save
(
self
,
data
:
dict
):
print
(
data
)
print
(
data
)
def
save_binary
(
self
,
data
:
Union
[
dict
,
List
[
dict
]])
->
str
:
if
self
.
output
is
None
:
raise
KeyError
(
'When using piped input on pipeline outputting large object requires an output file path. '
'Please provide such output path through --output argument.'
)
return
super
().
save_binary
(
data
)
class
_ScikitCompat
(
ABC
):
class
_ScikitCompat
(
ABC
):
"""
"""
...
@@ -205,11 +224,13 @@ class Pipeline(_ScikitCompat):
...
@@ -205,11 +224,13 @@ class Pipeline(_ScikitCompat):
Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
"""
"""
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
**
kwargs
):
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
binary_output
:
bool
=
False
):
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
device
=
device
self
.
device
=
device
self
.
binary_output
=
binary_output
self
.
_args_parser
=
args_parser
or
DefaultArgumentHandler
()
self
.
_args_parser
=
args_parser
or
DefaultArgumentHandler
()
# Special handling
# Special handling
...
@@ -325,6 +346,13 @@ class FeatureExtractionPipeline(Pipeline):
...
@@ -325,6 +346,13 @@ class FeatureExtractionPipeline(Pipeline):
"""
"""
Feature extraction pipeline using Model head.
Feature extraction pipeline using Model head.
"""
"""
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
):
super
().
__init__
(
model
,
tokenizer
,
args_parser
,
device
,
binary_output
=
True
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
super
().
__call__
(
*
args
,
**
kwargs
).
tolist
()
return
super
().
__call__
(
*
args
,
**
kwargs
).
tolist
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment