Commit bbaaec04 authored by thomwolf

fixing CLI pipeline

parent 1c12ee0e
@@ -9,6 +9,9 @@ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

 def try_infer_format_from_ext(path: str):
+    if not path:
+        return 'pipe'
+
     for ext in PipelineDataFormat.SUPPORTED_FORMATS:
         if path.endswith(ext):
             return ext
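With this change, an empty or missing --input path falls back to the 'pipe' format (read from stdin) instead of failing on the extension loop. A minimal standalone sketch of the intended behavior; `infer` is a local stand-in for the function above, and the error path after the loop is assumed, since the diff elides it:

    # Illustration only, not part of the commit.
    SUPPORTED_FORMATS = ['json', 'csv', 'pipe']

    def infer(path):
        if not path:                      # no --input given: read from stdin
            return 'pipe'
        for ext in SUPPORTED_FORMATS:     # otherwise match on the file extension
            if path.endswith(ext):
                return ext

    assert infer('') == 'pipe'
    assert infer('data.csv') == 'csv'
    assert infer('records.json') == 'json'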
@@ -20,9 +23,16 @@ def try_infer_format_from_ext(path: str):

 def run_command_factory(args):
-    nlp = pipeline(task=args.task, model=args.model, config=args.config, tokenizer=args.tokenizer, device=args.device)
+    nlp = pipeline(task=args.task,
+                   model=args.model if args.model else None,
+                   config=args.config,
+                   tokenizer=args.tokenizer,
+                   device=args.device)
     format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
-    reader = PipelineDataFormat.from_str(format, args.output, args.input, args.column)
+    reader = PipelineDataFormat.from_str(format=format,
+                                         output_path=args.output,
+                                         input_path=args.input,
+                                         column=args.column if args.column else nlp.default_input_names)
     return RunCommand(nlp, reader)
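The factory now passes the from_str arguments by keyword and, when --column is omitted, falls back to the pipeline's default_input_names (added further down in pipelines.py). A hedged mirror of what it would do for a question-answering run; the paths are hypothetical, the import location and the behavior of passing model=None (picking the task's default model) are assumptions implied by the `args.model if args.model else None` change:

    from transformers.pipelines import pipeline, PipelineDataFormat  # assumed import path at this commit

    nlp = pipeline(task='question-answering',
                   model=None,            # --model omitted: factory passes None through
                   config=None, tokenizer=None, device=-1)
    reader = PipelineDataFormat.from_str(format='csv',
                                         output_path='answers.csv',    # hypothetical paths
                                         input_path='questions.csv',
                                         column=nlp.default_input_names)  # -> 'question,context'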
@@ -35,24 +45,26 @@ class RunCommand(BaseTransformersCLICommand):

     @staticmethod
     def register_subcommand(parser: ArgumentParser):
         run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
-        run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
-        run_parser.add_argument('--model', type=str, required=True, help='Name or path to the model to instantiate.')
+        run_parser.add_argument('--input', type=str, help='Path to the file to use for inference')
+        run_parser.add_argument('--output', type=str, help='Path to the file that will be used post to write results.')
+        run_parser.add_argument('--model', type=str, help='Name or path to the model to instantiate.')
         run_parser.add_argument('--config', type=str, help='Name or path to the model\'s config to instantiate.')
         run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
         run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
         run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
-        run_parser.add_argument('--input', type=str, help='Path to the file to use for inference')
-        run_parser.add_argument('--output', type=str, help='Path to the file that will be used post to write results.')
+        run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         run_parser.set_defaults(func=run_command_factory)

     def run(self):
-        nlp, output = self._nlp, []
+        nlp, outputs = self._nlp, []

         for entry in self._reader:
-            if self._reader.is_multi_columns:
-                output += nlp(**entry)
+            output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
+            if isinstance(output, dict):
+                outputs.append(output)
             else:
-                output += nlp(entry)

+                outputs += output

         # Saving data
         if self._nlp.binary_output:
...
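The rewritten loop accumulates per-entry results uniformly: a pipeline call may return a single dict (one result) or a list of dicts (several results), and both now end up in one flat outputs list, where the old `output += nlp(**entry)` would have iterated over a dict's keys. A minimal reproduction of the new accumulation semantics, with made-up result dicts:

    # Illustration only: dict results are appended, list results are extended.
    outputs = []
    for output in [{'answer': 'Paris'},                       # dict: single result
                   [{'answer': 'blue'}, {'answer': 'red'}]]:  # list: multiple results
        if isinstance(output, dict):
            outputs.append(output)
        else:
            outputs += output
    assert len(outputs) == 3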
@@ -14,12 +14,14 @@
 # limitations under the License.
 from __future__ import absolute_import, division, print_function, unicode_literals

+import sys
 import csv
 import json
 import os
 import pickle
 import logging

 import six

 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from itertools import groupby
@@ -98,28 +100,29 @@ class PipelineDataFormat:
     Supported data formats currently includes:
      - JSON
      - CSV
+     - stdin/stdout (pipe)

     PipelineDataFormat also includes some utilities to work with multi-columns like mapping from datasets columns
     to pipelines keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
     """
     SUPPORTED_FORMATS = ['json', 'csv', 'pipe']

-    def __init__(self, output: Optional[str], input: Optional[str], column: Optional[str]):
-        self.output = output
-        self.path = input
-        self.column = column.split(',') if column else ['']
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        self.output_path = output_path
+        self.input_path = input_path
+        self.column = column.split(',') if column is not None else ['']
         self.is_multi_columns = len(self.column) > 1

         if self.is_multi_columns:
             self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]

-        if output is not None:
-            if exists(abspath(self.output)):
-                raise OSError('{} already exists on disk'.format(self.output))
+        if output_path is not None:
+            if exists(abspath(self.output_path)):
+                raise OSError('{} already exists on disk'.format(self.output_path))

-        if input is not None:
-            if not exists(abspath(self.path)):
-                raise OSError('{} doesnt exist on disk'.format(self.path))
+        if input_path is not None:
+            if not exists(abspath(self.input_path)):
+                raise OSError('{} doesnt exist on disk'.format(self.input_path))

     @abstractmethod
     def __iter__(self):
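The column string maps dataset columns to pipeline keyword arguments; with multiple comma-separated entries, each one may optionally rename a column via kwarg=column, otherwise the name is used for both sides. A quick sketch of the parsing logic in __init__ above, with a hypothetical column spec:

    # Illustration of the multi-column parsing, not part of the diff.
    column = 'question=q,context'.split(',')          # ['question=q', 'context']
    if len(column) > 1:
        column = [tuple(c.split('=')) if '=' in c else (c, c) for c in column]
    assert column == [('question', 'q'), ('context', 'context')]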
@@ -140,7 +143,7 @@ class PipelineDataFormat:
         :param data: data to store
         :return: (str) Path where the data has been saved
         """
-        path, _ = os.path.splitext(self.output)
+        path, _ = os.path.splitext(self.output_path)
         binary_path = os.path.extsep.join((path, 'pickle'))

         with open(binary_path, 'wb+') as f_output:
@@ -149,23 +152,23 @@ class PipelineDataFormat:
         return binary_path

     @staticmethod
-    def from_str(name: str, output: Optional[str], path: Optional[str], column: Optional[str]):
-        if name == 'json':
-            return JsonPipelineDataFormat(output, path, column)
-        elif name == 'csv':
-            return CsvPipelineDataFormat(output, path, column)
-        elif name == 'pipe':
-            return PipedPipelineDataFormat(output, path, column)
+    def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        if format == 'json':
+            return JsonPipelineDataFormat(output_path, input_path, column)
+        elif format == 'csv':
+            return CsvPipelineDataFormat(output_path, input_path, column)
+        elif format == 'pipe':
+            return PipedPipelineDataFormat(output_path, input_path, column)
         else:
-            raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(name))
+            raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(format))


 class CsvPipelineDataFormat(PipelineDataFormat):
-    def __init__(self, output: Optional[str], input: Optional[str], column: Optional[str]):
-        super().__init__(output, input, column)
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        super().__init__(output_path, input_path, column)

     def __iter__(self):
-        with open(self.path, 'r') as f:
+        with open(self.input_path, 'r') as f:
             reader = csv.DictReader(f)
             for row in reader:
                 if self.is_multi_columns:
@@ -174,7 +177,7 @@ class CsvPipelineDataFormat(PipelineDataFormat):
                 yield row[self.column[0]]

     def save(self, data: List[dict]):
-        with open(self.output, 'w') as f:
+        with open(self.output_path, 'w') as f:
             if len(data) > 0:
                 writer = csv.DictWriter(f, list(data[0].keys()))
                 writer.writeheader()
@@ -182,10 +185,10 @@ class CsvPipelineDataFormat(PipelineDataFormat):

 class JsonPipelineDataFormat(PipelineDataFormat):
-    def __init__(self, output: Optional[str], input: Optional[str], column: Optional[str]):
-        super().__init__(output, input, column)
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        super().__init__(output_path, input_path, column)

-        with open(input, 'r') as f:
+        with open(input_path, 'r') as f:
             self._entries = json.load(f)

     def __iter__(self):
@@ -196,7 +199,7 @@ class JsonPipelineDataFormat(PipelineDataFormat):
                 yield entry[self.column[0]]

     def save(self, data: dict):
-        with open(self.output, 'w') as f:
+        with open(self.output_path, 'w') as f:
             json.dump(data, f)
@@ -208,9 +211,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
     If columns are provided, then the output will be a dictionary with {column_x: value_x}
     """
     def __iter__(self):
-        import sys
-
         for line in sys.stdin:
             # Split for multi-columns
             if '\t' in line:
@@ -229,7 +230,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
             print(data)

     def save_binary(self, data: Union[dict, List[dict]]) -> str:
-        if self.output is None:
+        if self.output_path is None:
             raise KeyError(
                 'When using piped input on pipeline outputting large object requires an output file path. '
                 'Please provide such output path through --output argument.'
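PipedPipelineDataFormat reads one record per stdin line, splitting on tabs for multi-column input; binary output still needs a real file, hence the --output requirement above. A sketch of how one tab-separated line could map to pipeline keyword arguments, reusing the column tuples produced by __init__ (the line content and the dict-comprehension phrasing are illustrative, not lifted from the diff):

    # Hypothetical stdin line for a QA pipeline.
    line = 'Who wrote it?\tThe book was written by Jane.\n'
    columns = [('question', 'question'), ('context', 'context')]
    fields = line.strip().split('\t')
    kwargs = {kwarg: value for (kwarg, _), value in zip(columns, fields)}
    assert kwargs == {'question': 'Who wrote it?', 'context': 'The book was written by Jane.'}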
@@ -294,6 +295,9 @@ class Pipeline(_ScikitCompat):
         nlp = NerPipeline(model='...', config='...', tokenizer='...')
         nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
     """
+    default_input_names = None
+
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
                  modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
@@ -582,6 +586,8 @@ class QuestionAnsweringPipeline(Pipeline):
     Question Answering pipeline using ModelForQuestionAnswering head.
     """
+    default_input_names = 'question,context'
+
     def __init__(self, model,
                  tokenizer: Optional[PreTrainedTokenizer],
                  modelcard: Optional[ModelCard],
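default_input_names closes the loop with run_command_factory above: when --column is not given, the QA pipeline contributes 'question,context', which PipelineDataFormat.__init__ then splits into the [('question', 'question'), ('context', 'context')] mapping. A hedged illustration, assuming QuestionAnsweringPipeline is importable from transformers.pipelines at this commit:

    from transformers.pipelines import QuestionAnsweringPipeline  # assumed import path

    args_column = None                  # user did not pass --column
    column = args_column if args_column else QuestionAnsweringPipeline.default_input_names
    assert column == 'question,context'  # consumed by PipelineDataFormat.__init__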
@@ -684,7 +690,6 @@ class QuestionAnsweringPipeline(Pipeline):
                 }
                 for s, e, score in zip(starts, ends, scores)
             ]
-
             if len(answers) == 1:
                 return answers[0]
             return answers
...