Commit f516cf39 authored by Morgan Funtowicz
Browse files

Allow pipeline to write output in binary format

parent d72fa2a0
import logging
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import pipeline, Pipeline, PipelineDataFormat, SUPPORTED_TASKS
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def try_infer_format_from_ext(path: str):
for ext in PipelineDataFormat.SUPPORTED_FORMATS:
if path.endswith(ext):
......@@ -51,7 +55,11 @@ class RunCommand(BaseTransformersCLICommand):
output += [nlp(entry)]
# Saving data
self._reader.save(output)
if self._nlp.binary_output:
binary_path = self._reader.save_binary(output)
logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
else:
self._reader.save(output)
......@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import csv
import json
import os
import pickle
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import groupby
......@@ -91,6 +92,7 @@ class PipelineDataFormat:
if exists(abspath(self.output)):
raise OSError('{} already exists on disk'.format(self.output))
if path is not None:
if not exists(abspath(self.path)):
raise OSError('{} doesnt exist on disk'.format(self.path))
......@@ -102,6 +104,15 @@ class PipelineDataFormat:
def save(self, data: dict):
    """
    Persist `data` for this data format.

    This implementation does nothing but raise: the visible subclass
    PipedPipelineDataFormat supplies its own `save`.

    :param data: the content to save.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
def save_binary(self, data: Union[dict, List[dict]]) -> str:
    """
    Pickle `data` to a file located next to `self.output`.

    The target path is `self.output` with its extension (if any) replaced
    by `pickle`; the file at `self.output` itself is not written here.

    :param data: the object (or list of objects) to serialize.
    :return: the path of the pickle file that was written.
    """
    stem = os.path.splitext(self.output)[0]
    pickle_path = stem + os.path.extsep + 'pickle'

    with open(pickle_path, 'wb+') as handle:
        pickle.dump(data, handle)

    return pickle_path
@staticmethod
def from_str(name: str, output: Optional[str], path: Optional[str], column: Optional[str]):
if name == 'json':
......@@ -177,12 +188,20 @@ class PipedPipelineDataFormat(PipelineDataFormat):
# No dictionary to map arguments
else:
print(line)
yield line
def save(self, data: dict):
    """
    Emit `data` by printing it to standard output.

    :param data: the content to print.
    :return: None.
    """
    print(data)
    return None
def save_binary(self, data: Union[dict, List[dict]]) -> str:
    """
    Save `data` in binary form, requiring an explicit output path.

    When `self.output` is set, the work is delegated to the parent class
    implementation and its return value is passed through unchanged.

    :param data: the object (or list of objects) to serialize.
    :return: whatever the parent `save_binary` returns.
    :raises KeyError: if `self.output` is None.
    """
    if self.output is not None:
        return super().save_binary(data)

    # No output path was given, so there is nowhere to write the binary payload.
    raise KeyError(
        'When using piped input on pipeline outputting large object requires an output file path. '
        'Please provide such output path through --output argument.'
    )
class _ScikitCompat(ABC):
"""
......@@ -205,11 +224,13 @@ class Pipeline(_ScikitCompat):
Input -> Tokenization -> Model Inference -> Post-Processing (Task dependent) -> Output
"""
def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
args_parser: ArgumentHandler = None, device: int = -1, **kwargs):
args_parser: ArgumentHandler = None, device: int = -1,
binary_output: bool = False):
self.model = model
self.tokenizer = tokenizer
self.device = device
self.binary_output = binary_output
self._args_parser = args_parser or DefaultArgumentHandler()
# Special handling
......@@ -325,6 +346,13 @@ class FeatureExtractionPipeline(Pipeline):
"""
Feature extraction pipeline using Model head.
"""
def __init__(self, model,
             tokenizer: PreTrainedTokenizer = None,
             args_parser: ArgumentHandler = None,
             device: int = -1):
    """
    Build the pipeline by forwarding every argument to the parent
    initializer, with `binary_output` forced to True.

    :param model: forwarded unchanged to the parent initializer.
    :param tokenizer: forwarded unchanged to the parent initializer.
    :param args_parser: forwarded unchanged to the parent initializer.
    :param device: forwarded unchanged to the parent initializer.
    """
    super().__init__(
        model=model,
        tokenizer=tokenizer,
        args_parser=args_parser,
        device=device,
        binary_output=True,
    )
def __call__(self, *args, **kwargs):
    """
    Run the parent pipeline call with the given arguments and return its
    result converted through `.tolist()`.

    :param args: positional arguments forwarded to the parent `__call__`.
    :param kwargs: keyword arguments forwarded to the parent `__call__`.
    :return: the value of `.tolist()` invoked on the parent's result.
    """
    # NOTE(review): presumably the parent returns an array/tensor-like object
    # exposing `tolist` -- confirm against Pipeline.__call__.
    features = super().__call__(*args, **kwargs)
    return features.tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment