Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f516cf39
"ppocr/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "7a7059456a6f8faa377e3cef62fa650fe89380e0"
Commit
f516cf39
authored
Dec 19, 2019
by
Morgan Funtowicz
Browse files
Allow pipeline to write output in binary format
parent
d72fa2a0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
3 deletions
+39
-3
transformers/commands/run.py
transformers/commands/run.py
+9
-1
transformers/pipelines.py
transformers/pipelines.py
+30
-2
No files found.
transformers/commands/run.py
View file @
f516cf39
import
logging
from
argparse
import
ArgumentParser
from
argparse
import
ArgumentParser
from
transformers.commands
import
BaseTransformersCLICommand
from
transformers.commands
import
BaseTransformersCLICommand
from
transformers.pipelines
import
pipeline
,
Pipeline
,
PipelineDataFormat
,
SUPPORTED_TASKS
from
transformers.pipelines
import
pipeline
,
Pipeline
,
PipelineDataFormat
,
SUPPORTED_TASKS
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
def
try_infer_format_from_ext
(
path
:
str
):
def
try_infer_format_from_ext
(
path
:
str
):
for
ext
in
PipelineDataFormat
.
SUPPORTED_FORMATS
:
for
ext
in
PipelineDataFormat
.
SUPPORTED_FORMATS
:
if
path
.
endswith
(
ext
):
if
path
.
endswith
(
ext
):
...
@@ -51,7 +55,11 @@ class RunCommand(BaseTransformersCLICommand):
...
@@ -51,7 +55,11 @@ class RunCommand(BaseTransformersCLICommand):
output
+=
[
nlp
(
entry
)]
output
+=
[
nlp
(
entry
)]
# Saving data
# Saving data
self
.
_reader
.
save
(
output
)
if
self
.
_nlp
.
binary_output
:
binary_path
=
self
.
_reader
.
save_binary
(
output
)
logger
.
warning
(
'Current pipeline requires output to be in binary format, saving at {}'
.
format
(
binary_path
))
else
:
self
.
_reader
.
save
(
output
)
transformers/pipelines.py
View file @
f516cf39
...
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
...
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
csv
import
csv
import
json
import
json
import
os
import
os
import
pickle
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
itertools
import
groupby
from
itertools
import
groupby
...
@@ -91,6 +92,7 @@ class PipelineDataFormat:
...
@@ -91,6 +92,7 @@ class PipelineDataFormat:
if
exists
(
abspath
(
self
.
output
)):
if
exists
(
abspath
(
self
.
output
)):
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output
))
if
path
is
not
None
:
if
not
exists
(
abspath
(
self
.
path
)):
if
not
exists
(
abspath
(
self
.
path
)):
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
raise
OSError
(
'{} doesnt exist on disk'
.
format
(
self
.
path
))
...
@@ -102,6 +104,15 @@ class PipelineDataFormat:
...
@@ -102,6 +104,15 @@ class PipelineDataFormat:
def
save
(
self
,
data
:
dict
):
def
save
(
self
,
data
:
dict
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
save_binary
(
self
,
data
:
Union
[
dict
,
List
[
dict
]])
->
str
:
path
,
_
=
os
.
path
.
splitext
(
self
.
output
)
binary_path
=
os
.
path
.
extsep
.
join
((
path
,
'pickle'
))
with
open
(
binary_path
,
'wb+'
)
as
f_output
:
pickle
.
dump
(
data
,
f_output
)
return
binary_path
@
staticmethod
@
staticmethod
def
from_str
(
name
:
str
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
from_str
(
name
:
str
,
output
:
Optional
[
str
],
path
:
Optional
[
str
],
column
:
Optional
[
str
]):
if
name
==
'json'
:
if
name
==
'json'
:
...
@@ -177,12 +188,20 @@ class PipedPipelineDataFormat(PipelineDataFormat):
...
@@ -177,12 +188,20 @@ class PipedPipelineDataFormat(PipelineDataFormat):
# No dictionary to map arguments
# No dictionary to map arguments
else
:
else
:
print
(
line
)
yield
line
yield
line
def
save
(
self
,
data
:
dict
):
def
save
(
self
,
data
:
dict
):
print
(
data
)
print
(
data
)
def
save_binary
(
self
,
data
:
Union
[
dict
,
List
[
dict
]])
->
str
:
if
self
.
output
is
None
:
raise
KeyError
(
'When using piped input on pipeline outputting large object requires an output file path. '
'Please provide such output path through --output argument.'
)
return
super
().
save_binary
(
data
)
class
_ScikitCompat
(
ABC
):
class
_ScikitCompat
(
ABC
):
"""
"""
...
@@ -205,11 +224,13 @@ class Pipeline(_ScikitCompat):
...
@@ -205,11 +224,13 @@ class Pipeline(_ScikitCompat):
Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
"""
"""
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
**
kwargs
):
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
binary_output
:
bool
=
False
):
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
device
=
device
self
.
device
=
device
self
.
binary_output
=
binary_output
self
.
_args_parser
=
args_parser
or
DefaultArgumentHandler
()
self
.
_args_parser
=
args_parser
or
DefaultArgumentHandler
()
# Special handling
# Special handling
...
@@ -325,6 +346,13 @@ class FeatureExtractionPipeline(Pipeline):
...
@@ -325,6 +346,13 @@ class FeatureExtractionPipeline(Pipeline):
"""
"""
Feature extraction pipeline using Model head.
Feature extraction pipeline using Model head.
"""
"""
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
):
super
().
__init__
(
model
,
tokenizer
,
args_parser
,
device
,
binary_output
=
True
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
super
().
__call__
(
*
args
,
**
kwargs
).
tolist
()
return
super
().
__call__
(
*
args
,
**
kwargs
).
tolist
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment