Commit 955d7ecb authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Refactored Pipeline with dedicated argument handler.

parent 8e3b1c86
......@@ -36,29 +36,40 @@ if is_torch_available():
AutoModelForQuestionAnswering, AutoModelForTokenClassification
class Pipeline(ABC):
def __init__(self, model, tokenizer: PreTrainedTokenizer = None, **kwargs):
self.model = model
self.tokenizer = tokenizer
class ArgumentHandler(ABC):
"""
Base interface for handling varargs for each Pipeline
"""
@abstractmethod
def __call__(self, *args, **kwargs):
raise NotImplementedError()
def save_pretrained(self, save_directory):
if not os.path.isdir(save_directory):
logger.error("Provided path ({}) should be a directory".format(save_directory))
return
self.model.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
class DefaultArgumentHandler(ArgumentHandler):
"""
Default varargs argument parser handling parameters for each Pipeline
"""
def __call__(self, *args, **kwargs):
if 'X' in kwargs:
return kwargs['X']
elif 'data' in kwargs:
return kwargs['data']
elif len(args) > 0:
return list(args)
raise ValueError('Unable to infer the format of the provided data (X=, data=, ...)')
def transform(self, *texts, **kwargs):
# Generic compatibility with sklearn and Keras
return self(*texts, **kwargs)
def predict(self, *texts, **kwargs):
# Generic compatibility with sklearn and Keras
return self(*texts, **kwargs)
class _ScikitCompat(ABC):
"""
Interface layer for the Scikit and Keras compatibility.
"""
@abstractmethod
def __call__(self, *texts, **kwargs):
def transform(self, X):
raise NotImplementedError()
@abstractmethod
def predict(self, X):
raise NotImplementedError()
......@@ -133,24 +144,45 @@ class JsonPipelineDataFormat(PipelineDataFormat):
if self.is_multi_columns:
yield {k: entry[c] for k, c in self.column}
else:
yield entry[self.column]
yield entry[self.column[0]]
def save(self, data: dict):
with open(self.output, 'w') as f:
json.dump(data, f)
class FeatureExtractionPipeline(Pipeline):
class Pipeline(_ScikitCompat):
def __init__(self, model, tokenizer: PreTrainedTokenizer = None, args_parser: ArgumentHandler = None, **kwargs):
self.model = model
self.tokenizer = tokenizer
self._args_parser = args_parser or DefaultArgumentHandler()
def save_pretrained(self, save_directory):
if not os.path.isdir(save_directory):
logger.error("Provided path ({}) should be a directory".format(save_directory))
return
self.model.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
def transform(self, X):
return self(X=X)
def predict(self, X):
return self(X=X)
def __call__(self, *texts, **kwargs):
# Generic compatibility with sklearn and Keras
if 'X' in kwargs and not texts:
texts = kwargs.pop('X')
# Parse arguments
inputs = self._args_parser(*texts, **kwargs)
# Encode for forward
inputs = self.tokenizer.batch_encode_plus(
texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
inputs, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
)
return self._forward(inputs)
def _forward(self, inputs):
if is_tf_available():
# TODO trace model
predictions = self.model(inputs)[0]
......@@ -159,7 +191,12 @@ class FeatureExtractionPipeline(Pipeline):
with torch.no_grad():
predictions = self.model(**inputs)[0]
return predictions.numpy().tolist()
return predictions.numpy()
class FeatureExtractionPipeline(Pipeline):
def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs).tolist()
class TextClassificationPipeline(Pipeline):
......@@ -170,26 +207,8 @@ class TextClassificationPipeline(Pipeline):
raise Exception('Invalid parameter nb_classes. int >= 2 is required (got: {})'.format(nb_classes))
self._nb_classes = nb_classes
def __call__(self, *texts, **kwargs):
# Generic compatibility with sklearn and Keras
if 'X' in kwargs and not texts:
texts = kwargs.pop('X')
inputs = self.tokenizer.batch_encode_plus(
texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
)
special_tokens_mask = inputs.pop('special_tokens_mask')
if is_tf_available():
# TODO trace model
predictions = self.model(**inputs)[0]
else:
import torch
with torch.no_grad():
predictions = self.model(**inputs)[0]
return predictions.numpy().tolist()
def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs).tolist()
class NerPipeline(Pipeline):
......@@ -198,8 +217,7 @@ class NerPipeline(Pipeline):
super().__init__(model, tokenizer)
def __call__(self, *texts, **kwargs):
(texts, ), answers = texts, []
inputs, answers = self._args_parser(*texts, **kwargs), []
for sentence in texts:
# Ugly token to word idx mapping (for now)
......@@ -241,24 +259,12 @@ class QuestionAnsweringPipeline(Pipeline):
Question Answering pipeline involving Tokenization and Inference.
"""
@classmethod
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
pass
@staticmethod
def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
is_list = isinstance(question, list)
if is_list:
return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
else:
return SquadExample(None, question, context, None, None, None)
class QuestionAnsweringArgumentHandler(ArgumentHandler):
@staticmethod
def handle_args(*inputs, **kwargs) -> List[SquadExample]:
def __call__(self, *args, **kwargs):
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
if inputs is not None and len(inputs) > 1:
kwargs['X'] = inputs
if args is not None and len(args) > 1:
kwargs['X'] = args
# Generic compatibility with sklearn and Keras
# Batched data
......@@ -300,8 +306,17 @@ class QuestionAnsweringPipeline(Pipeline):
return inputs
@staticmethod
def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
is_list = isinstance(question, list)
if is_list:
return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, context)]
else:
return SquadExample(None, question, context, None, None, None)
def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer]):
super().__init__(model, tokenizer)
super().__init__(model, tokenizer, args_parser=QuestionAnsweringPipeline.QuestionAnsweringArgumentHandler())
def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
args = ['input_ids', 'attention_mask']
......@@ -332,9 +347,8 @@ class QuestionAnsweringPipeline(Pipeline):
if kwargs['max_answer_len'] < 1:
raise ValueError('max_answer_len parameter should be >= 1 (got {})'.format(kwargs['max_answer_len']))
examples = QuestionAnsweringPipeline.handle_args(texts, **kwargs)
# Convert inputs to features
examples = self._args_parser(*texts, **kwargs)
features = squad_convert_examples_to_features(examples, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False)
fw_args = self.inputs_for_model(features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment