"ppocr/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "7a7059456a6f8faa377e3cef62fa650fe89380e0"
Commit f516cf39 authored by Morgan Funtowicz

Allow pipeline to write output in binary format

parent d72fa2a0
import logging
from argparse import ArgumentParser

from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import pipeline, Pipeline, PipelineDataFormat, SUPPORTED_TASKS
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def try_infer_format_from_ext(path: str):
    for ext in PipelineDataFormat.SUPPORTED_FORMATS:
        if path.endswith(ext):
@@ -51,7 +55,11 @@ class RunCommand(BaseTransformersCLICommand):
            output += [nlp(entry)]

        # Saving data
        if self._nlp.binary_output:
            binary_path = self._reader.save_binary(output)
            logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
        else:
            self._reader.save(output)
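Here RunCommand now branches on the pipeline's binary_output flag: pipelines marked as producing binary output are pickled via save_binary instead of written through the text-oriented save. A minimal sketch of reading such a file back (not part of the commit; 'output.pickle' is a hypothetical path like the one reported by the warning above):

import pickle

# Hypothetical path; the real name is derived from the --output argument.
with open('output.pickle', 'rb') as f_input:
    results = pickle.load(f_input)   # the list of pipeline outputs built by RunCommand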
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import csv
import json
import os
import pickle
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import groupby
@@ -91,6 +92,7 @@ class PipelineDataFormat:
        if exists(abspath(self.output)):
            raise OSError('{} already exists on disk'.format(self.output))

        if path is not None:
            if not exists(abspath(self.path)):
                raise OSError('{} doesnt exist on disk'.format(self.path))
@@ -102,6 +104,15 @@ class PipelineDataFormat:
    def save(self, data: dict):
        raise NotImplementedError()
    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        path, _ = os.path.splitext(self.output)
        binary_path = os.path.extsep.join((path, 'pickle'))

        with open(binary_path, 'wb+') as f_output:
            pickle.dump(data, f_output)

        return binary_path
    @staticmethod
    def from_str(name: str, output: Optional[str], path: Optional[str], column: Optional[str]):
        if name == 'json':
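The base save_binary above swaps the output file's extension for '.pickle' and dumps the data there. A standalone sketch of that path logic (not part of the commit; 'results.txt' and save_binary_like are illustrative names):

import os
import pickle

def save_binary_like(output, data):
    # 'results.txt' -> ('results', '.txt'); keep only the stem
    path, _ = os.path.splitext(output)
    # 'results' + '.' + 'pickle' -> 'results.pickle'
    binary_path = os.path.extsep.join((path, 'pickle'))
    with open(binary_path, 'wb') as f_output:
        pickle.dump(data, f_output)
    return binary_path

print(save_binary_like('results.txt', [{'score': 0.9}]))   # prints: results.pickle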
@@ -177,12 +188,20 @@ class PipedPipelineDataFormat(PipelineDataFormat):
            # No dictionary to map arguments
            else:
                print(line)
                yield line

    def save(self, data: dict):
        print(data)
    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        if self.output is None:
            raise KeyError(
                'When using piped input on pipeline outputting large object requires an output file path. '
                'Please provide such output path through --output argument.'
            )

        return super().save_binary(data)
class _ScikitCompat(ABC):
    """
@@ -205,11 +224,13 @@ class Pipeline(_ScikitCompat):
    Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
    """
    def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
                 args_parser: ArgumentHandler = None, device: int = -1,
                 binary_output: bool = False):

        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.binary_output = binary_output
        self._args_parser = args_parser or DefaultArgumentHandler()
        # Special handling
@@ -325,6 +346,13 @@ class FeatureExtractionPipeline(Pipeline):
""" """
Feature extraction pipeline using Model head. Feature extraction pipeline using Model head.
""" """
    def __init__(self, model,
                 tokenizer: PreTrainedTokenizer = None,
                 args_parser: ArgumentHandler = None,
                 device: int = -1):
        super().__init__(model, tokenizer, args_parser, device, binary_output=True)
    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs).tolist()
...
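Taken together, FeatureExtractionPipeline now declares binary_output=True, so its (potentially large) embedding outputs are routed through save_binary by the run command. A hedged usage sketch, assuming the AutoModel/AutoTokenizer helpers and the distilbert-base-uncased checkpoint, none of which appear in this diff:

from transformers import AutoModel, AutoTokenizer
from transformers.pipelines import FeatureExtractionPipeline

model = AutoModel.from_pretrained('distilbert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

nlp = FeatureExtractionPipeline(model, tokenizer)
print(nlp.binary_output)        # True, set by the new __init__ above
features = nlp('Hello world')   # nested list of floats (one vector per token)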