"vscode:/vscode.git/clone" did not exist on "e344e45f2d8fa5d8b1024bc2e8b740b873652dd4"
Commit f1971bf3 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Binding pipelines to the cli.

parent 0b51532c
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand from transformers.commands.download import DownloadCommand
from transformers.commands.run import RunCommand
from transformers.commands.serving import ServeCommand from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands from transformers.commands.user import UserCommands
from transformers.commands.train import TrainCommand from transformers.commands.train import TrainCommand
...@@ -14,9 +15,10 @@ if __name__ == '__main__': ...@@ -14,9 +15,10 @@ if __name__ == '__main__':
# Register commands # Register commands
ConvertCommand.register_subcommand(commands_parser) ConvertCommand.register_subcommand(commands_parser)
DownloadCommand.register_subcommand(commands_parser) DownloadCommand.register_subcommand(commands_parser)
RunCommand.register_subcommand(commands_parser)
ServeCommand.register_subcommand(commands_parser) ServeCommand.register_subcommand(commands_parser)
UserCommands.register_subcommand(commands_parser)
TrainCommand.register_subcommand(commands_parser) TrainCommand.register_subcommand(commands_parser)
UserCommands.register_subcommand(commands_parser)
# Let's go # Let's go
args = parser.parse_args() args = parser.parse_args()
......
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import pipeline, Pipeline, PipelineDataFormat, SUPPORTED_TASKS
def try_infer_format_from_ext(path: str):
for ext in PipelineDataFormat.SUPPORTED_FORMATS:
if path.endswith(ext):
return ext
raise Exception(
'Unable to determine file format from file extension {}. '
'Please provide the format through --format {}'.format(path, PipelineDataFormat.SUPPORTED_FORMATS)
)
def run_command_factory(args):
nlp = pipeline(task=args.task, model=args.model, tokenizer=args.tokenizer)
format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
reader = PipelineDataFormat.from_str(format, args.output, args.input, args.column)
return RunCommand(nlp, reader)
class RunCommand(BaseTransformersCLICommand):
def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
self._nlp = nlp
self._reader = reader
@staticmethod
def register_subcommand(parser: ArgumentParser):
run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
run_parser.add_argument('--model', type=str, required=True, help='Name or path to the model to instantiate.')
run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
run_parser.add_argument('--column', type=str, required=True, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
run_parser.add_argument('--input', type=str, required=True, help='Path to the file to use for inference')
run_parser.add_argument('--output', type=str, required=True, help='Path to the file that will be used post to write results.')
run_parser.add_argument('kwargs', nargs='*', help='Arguments to forward to the file format reader')
run_parser.set_defaults(func=run_command_factory)
def run(self):
nlp, output = self._nlp, []
for entry in self._reader:
if self._reader.is_multi_columns:
output += [nlp(**entry)]
else:
output += [nlp(entry)]
# Saving data
self._reader.save(output)
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import csv
import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import groupby from itertools import groupby
...@@ -25,11 +27,13 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \ ...@@ -25,11 +27,13 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger
if is_tf_available(): if is_tf_available():
from transformers import TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, TFAutoModelForTokenClassification from transformers import TFAutoModel, TFAutoModelForSequenceClassification, \
TFAutoModelForQuestionAnswering, TFAutoModelForTokenClassification
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelForTokenClassification from transformers import AutoModel, AutoModelForSequenceClassification, \
AutoModelForQuestionAnswering, AutoModelForTokenClassification
class Pipeline(ABC): class Pipeline(ABC):
...@@ -58,6 +62,84 @@ class Pipeline(ABC): ...@@ -58,6 +62,84 @@ class Pipeline(ABC):
raise NotImplementedError() raise NotImplementedError()
class PipelineDataFormat:
SUPPORTED_FORMATS = ['json', 'csv']
def __init__(self, output: str, path: str, column: str):
self.output = output
self.path = path
self.column = column.split(',')
self.is_multi_columns = len(self.column) > 1
if self.is_multi_columns:
self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]
from os.path import abspath, exists
if exists(abspath(self.output)):
raise OSError('{} already exists on disk'.format(self.output))
if not exists(abspath(self.path)):
raise OSError('{} doesnt exist on disk'.format(self.path))
@abstractmethod
def __iter__(self):
raise NotImplementedError()
@abstractmethod
def save(self, data: dict):
raise NotImplementedError()
@staticmethod
def from_str(name: str, output: str, path: str, column: str):
if name == 'json':
return JsonPipelineDataFormat(output, path, column)
elif name == 'csv':
return CsvPipelineDataFormat(output, path, column)
else:
raise KeyError('Unknown reader {} (Available reader are json/csv)'.format(name))
class CsvPipelineDataFormat(PipelineDataFormat):
def __init__(self, output: str, path: str, column: str):
super().__init__(output, path, column)
def __iter__(self):
with open(self.path, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
if self.is_multi_columns:
yield {k: row[c] for k, c in self.column}
else:
yield row[self.column]
def save(self, data: List[dict]):
with open(self.output, 'w') as f:
if len(data) > 0:
writer = csv.DictWriter(f, list(data[0].keys()))
writer.writeheader()
writer.writerows(data)
class JsonPipelineDataFormat(PipelineDataFormat):
def __init__(self, output: str, path: str, column: str):
super().__init__(output, path, column)
with open(path, 'r') as f:
self._entries = json.load(f)
def __iter__(self):
for entry in self._entries:
if self.is_multi_columns:
yield {k: entry[c] for k, c in self.column}
else:
yield entry[self.column]
def save(self, data: dict):
with open(self.output, 'w') as f:
json.dump(data, f)
class FeatureExtractionPipeline(Pipeline): class FeatureExtractionPipeline(Pipeline):
def __call__(self, *texts, **kwargs): def __call__(self, *texts, **kwargs):
...@@ -127,7 +209,7 @@ class NerPipeline(Pipeline): ...@@ -127,7 +209,7 @@ class NerPipeline(Pipeline):
label_idx = score.argmax() label_idx = score.argmax()
answer += [{ answer += [{
'word': words[idx - 1], 'score': score[label_idx], 'entity': self.model.config.id2label[label_idx] 'word': words[idx - 1], 'score': score[label_idx].item(), 'entity': self.model.config.id2label[label_idx]
}] }]
# Update token start # Update token start
...@@ -270,16 +352,18 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -270,16 +352,18 @@ class QuestionAnsweringPipeline(Pipeline):
char_to_word = np.array(example.char_to_word_offset) char_to_word = np.array(example.char_to_word_offset)
# Convert the answer (tokens) back to the original text # Convert the answer (tokens) back to the original text
answers += [[ answers += [
{ {
'score': score, 'score': score.item(),
'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0], 'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1], 'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]: feature.token_to_orig_map[e] + 1]) 'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]: feature.token_to_orig_map[e] + 1])
} }
for s, e, score in zip(starts, ends, scores) for s, e, score in zip(starts, ends, scores)
]] ]
if len(answers) == 1:
return answers[0]
return answers return answers
def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple: def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
...@@ -363,7 +447,7 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni ...@@ -363,7 +447,7 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni
Utility factory method to build pipeline. Utility factory method to build pipeline.
""" """
# Try to infer tokenizer from model name (if provided as str) # Try to infer tokenizer from model name (if provided as str)
if not isinstance(tokenizer, PreTrainedTokenizer): if tokenizer is None:
if not isinstance(model, str): if not isinstance(model, str):
# Impossible to guest what is the right tokenizer here # Impossible to guest what is the right tokenizer here
raise Exception('Tokenizer cannot be None if provided model is a PreTrainedModel instance') raise Exception('Tokenizer cannot be None if provided model is a PreTrainedModel instance')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment