Commit 955d7ecb authored by Morgan Funtowicz
Browse files

Refactored Pipeline with dedicated argument handler.

parent 8e3b1c86
...@@ -36,29 +36,40 @@ if is_torch_available(): ...@@ -36,29 +36,40 @@ if is_torch_available():
AutoModelForQuestionAnswering, AutoModelForTokenClassification AutoModelForQuestionAnswering, AutoModelForTokenClassification
class Pipeline(ABC): class ArgumentHandler(ABC):
def __init__(self, model, tokenizer: PreTrainedTokenizer = None, **kwargs): """
self.model = model Base interface for handling varargs for each Pipeline
self.tokenizer = tokenizer """
@abstractmethod
def __call__(self, *args, **kwargs):
raise NotImplementedError()
def save_pretrained(self, save_directory):
if not os.path.isdir(save_directory):
logger.error("Provided path ({}) should be a directory".format(save_directory))
return
self.model.save_pretrained(save_directory) class DefaultArgumentHandler(ArgumentHandler):
self.tokenizer.save_pretrained(save_directory) """
Default varargs argument parser handling parameters for each Pipeline
"""
def __call__(self, *args, **kwargs):
if 'X' in kwargs:
return kwargs['X']
elif 'data' in kwargs:
return kwargs['data']
elif len(args) > 0:
return list(args)
raise ValueError('Unable to infer the format of the provided data (X=, data=, ...)')
def transform(self, *texts, **kwargs):
# Generic compatibility with sklearn and Keras
return self(*texts, **kwargs)
def predict(self, *texts, **kwargs): class _ScikitCompat(ABC):
# Generic compatibility with sklearn and Keras """
return self(*texts, **kwargs) Interface layer for the Scikit and Keras compatibility.
"""
@abstractmethod @abstractmethod
def __call__(self, *texts, **kwargs): def transform(self, X):
raise NotImplementedError()
@abstractmethod
def predict(self, X):
raise NotImplementedError() raise NotImplementedError()
...@@ -133,24 +144,45 @@ class JsonPipelineDataFormat(PipelineDataFormat): ...@@ -133,24 +144,45 @@ class JsonPipelineDataFormat(PipelineDataFormat):
if self.is_multi_columns: if self.is_multi_columns:
yield {k: entry[c] for k, c in self.column} yield {k: entry[c] for k, c in self.column}
else: else:
yield entry[self.column] yield entry[self.column[0]]
def save(self, data: dict): def save(self, data: dict):
with open(self.output, 'w') as f: with open(self.output, 'w') as f:
json.dump(data, f) json.dump(data, f)
class FeatureExtractionPipeline(Pipeline): class Pipeline(_ScikitCompat):
def __init__(self, model, tokenizer: PreTrainedTokenizer = None, args_parser: ArgumentHandler = None, **kwargs):
self.model = model
self.tokenizer = tokenizer
self._args_parser = args_parser or DefaultArgumentHandler()
def save_pretrained(self, save_directory):
if not os.path.isdir(save_directory):
logger.error("Provided path ({}) should be a directory".format(save_directory))
return
self.model.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
def transform(self, X):
return self(X=X)
def predict(self, X):
return self(X=X)
def __call__(self, *texts, **kwargs): def __call__(self, *texts, **kwargs):
# Generic compatibility with sklearn and Keras # Parse arguments
if 'X' in kwargs and not texts: inputs = self._args_parser(*texts, **kwargs)
texts = kwargs.pop('X')
# Encode for forward
inputs = self.tokenizer.batch_encode_plus( inputs = self.tokenizer.batch_encode_plus(
texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt' inputs, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
) )
return self._forward(inputs)
def _forward(self, inputs):
if is_tf_available(): if is_tf_available():
# TODO trace model # TODO trace model
predictions = self.model(inputs)[0] predictions = self.model(inputs)[0]
...@@ -159,7 +191,12 @@ class FeatureExtractionPipeline(Pipeline): ...@@ -159,7 +191,12 @@ class FeatureExtractionPipeline(Pipeline):
with torch.no_grad(): with torch.no_grad():
predictions = self.model(**inputs)[0] predictions = self.model(**inputs)[0]
return predictions.numpy().tolist() return predictions.numpy()
class FeatureExtractionPipeline(Pipeline):
def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs).tolist()
class TextClassificationPipeline(Pipeline): class TextClassificationPipeline(Pipeline):
...@@ -170,26 +207,8 @@ class TextClassificationPipeline(Pipeline): ...@@ -170,26 +207,8 @@ class TextClassificationPipeline(Pipeline):
raise Exception('Invalid parameter nb_classes. int >= 2 is required (got: {})'.format(nb_classes)) raise Exception('Invalid parameter nb_classes. int >= 2 is required (got: {})'.format(nb_classes))
self._nb_classes = nb_classes self._nb_classes = nb_classes
def __call__(self, *texts, **kwargs): def __call__(self, *args, **kwargs):
# Generic compatibility with sklearn and Keras return super().__call__(*args, **kwargs).tolist()
if 'X' in kwargs and not texts:
texts = kwargs.pop('X')
inputs = self.tokenizer.batch_encode_plus(
texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
)
special_tokens_mask = inputs.pop('special_tokens_mask')
if is_tf_available():
# TODO trace model
predictions = self.model(**inputs)[0]
else:
import torch
with torch.no_grad():
predictions = self.model(**inputs)[0]
return predictions.numpy().tolist()
class NerPipeline(Pipeline): class NerPipeline(Pipeline):
...@@ -198,8 +217,7 @@ class NerPipeline(Pipeline): ...@@ -198,8 +217,7 @@ class NerPipeline(Pipeline):
super().__init__(model, tokenizer) super().__init__(model, tokenizer)
def __call__(self, *texts, **kwargs): def __call__(self, *texts, **kwargs):
(texts, ), answers = texts, [] inputs, answers = self._args_parser(*texts, **kwargs), []
for sentence in texts: for sentence in texts:
# Ugly token to word idx mapping (for now) # Ugly token to word idx mapping (for now)
...@@ -241,9 +259,52 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -241,9 +259,52 @@ class QuestionAnsweringPipeline(Pipeline):
Question Answering pipeline involving Tokenization and Inference. Question Answering pipeline involving Tokenization and Inference.
""" """
@classmethod class QuestionAnsweringArgumentHandler(ArgumentHandler):
def from_config(cls, model, tokenizer: PreTrainedTokenizer, **kwargs):
pass def __call__(self, *args, **kwargs):
# Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating
if args is not None and len(args) > 1:
kwargs['X'] = args
# Generic compatibility with sklearn and Keras
# Batched data
if 'X' in kwargs or 'data' in kwargs:
data = kwargs['X'] if 'X' in kwargs else kwargs['data']
if not isinstance(data, list):
data = [data]
for i, item in enumerate(data):
if isinstance(item, dict):
if any(k not in item for k in ['question', 'context']):
raise KeyError('You need to provide a dictionary with keys {question:..., context:...}')
data[i] = QuestionAnsweringPipeline.create_sample(**item)
elif isinstance(item, SquadExample):
continue
else:
raise ValueError(
'{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
.format('X' if 'X' in kwargs else 'data')
)
inputs = data
# Tabular input
elif 'question' in kwargs and 'context' in kwargs:
if isinstance(kwargs['question'], str):
kwargs['question'] = [kwargs['question']]
if isinstance(kwargs['context'], str):
kwargs['context'] = [kwargs['context']]
inputs = [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs['question'], kwargs['context'])]
else:
raise ValueError('Unknown arguments {}'.format(kwargs))
if not isinstance(inputs, list):
inputs = [inputs]
return inputs
@staticmethod @staticmethod
def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]: def create_sample(question: Union[str, List[str]], context: Union[str, List[str]]) -> Union[SquadExample, List[SquadExample]]:
...@@ -254,54 +315,8 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -254,54 +315,8 @@ class QuestionAnsweringPipeline(Pipeline):
else: else:
return SquadExample(None, question, context, None, None, None) return SquadExample(None, question, context, None, None, None)
@staticmethod
def handle_args(*inputs, **kwargs) -> List[SquadExample]:
    """
    Normalize the supported calling conventions into a flat list of examples.

    Accepted forms, in order of precedence:

    1. Batched keyword: ``X=...`` or ``data=...`` (``X`` wins when both are
       given). The value may be a single item, a list or a tuple; every item
       must be a ``dict`` holding ``question`` and ``context`` keys, or a
       ``SquadExample``.
    2. Tabular keywords: ``question=...`` together with ``context=...``, each
       either a single ``str`` or an iterable; the two are paired with ``zip``.
    3. Positional items, validated exactly like the ``X`` form.

    Explicit keywords take precedence over positionals so that every call
    which already succeeded keeps resolving the same way.

    Args:
        *inputs: Examples given positionally. A lone list/tuple positional is
            unwrapped, which supports a caller forwarding its own ``*args``
            tuple as one argument.
        **kwargs: ``X``, ``data``, or ``question`` and ``context``.

    Returns:
        A newly built list; the caller's containers are never modified.

    Raises:
        KeyError: A dict item lacks the ``question`` or ``context`` key.
        ValueError: An item has an unsupported type, or no recognised
            argument was supplied.
    """
    # A caller may hand over its own varargs tuple as a single positional
    # (possibly itself wrapping one list); peel those layers so we look at
    # the actual items. An empty tuple stays empty and falls through.
    while len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
        inputs = inputs[0]

    if 'X' in kwargs or 'data' in kwargs:
        # Generic compatibility with sklearn and Keras (batched data)
        source = 'X' if 'X' in kwargs else 'data'
        data = kwargs[source]
    elif 'question' in kwargs and 'context' in kwargs:
        # Tabular input: promote bare strings to one-element lists so zip pairs them up
        questions, contexts = kwargs['question'], kwargs['context']
        if isinstance(questions, str):
            questions = [questions]
        if isinstance(contexts, str):
            contexts = [contexts]
        return [QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(questions, contexts)]
    elif inputs:
        # Positional items are handled the same way as X; errors are reported under that name
        source = 'X'
        data = inputs
    else:
        raise ValueError('Unknown arguments {}'.format(kwargs))

    # Tuples are a batch just like lists; anything else is a single item
    if isinstance(data, tuple):
        data = list(data)
    elif not isinstance(data, list):
        data = [data]

    # Build a fresh list instead of assigning into `data`, which may be the caller's own list
    examples = []
    for item in data:
        if isinstance(item, dict):
            if any(k not in item for k in ['question', 'context']):
                raise KeyError('You need to provide a dictionary with keys {question:..., context:...}')
            examples.append(QuestionAnsweringPipeline.create_sample(**item))
        elif isinstance(item, SquadExample):
            examples.append(item)
        else:
            raise ValueError(
                '{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)'
                .format(source)
            )

    return examples
def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer]): def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer]):
super().__init__(model, tokenizer) super().__init__(model, tokenizer, args_parser=QuestionAnsweringPipeline.QuestionAnsweringArgumentHandler())
def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict: def inputs_for_model(self, features: Union[SquadExample, List[SquadExample]]) -> Dict:
args = ['input_ids', 'attention_mask'] args = ['input_ids', 'attention_mask']
...@@ -332,9 +347,8 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -332,9 +347,8 @@ class QuestionAnsweringPipeline(Pipeline):
if kwargs['max_answer_len'] < 1: if kwargs['max_answer_len'] < 1:
raise ValueError('max_answer_len parameter should be >= 1 (got {})'.format(kwargs['max_answer_len'])) raise ValueError('max_answer_len parameter should be >= 1 (got {})'.format(kwargs['max_answer_len']))
examples = QuestionAnsweringPipeline.handle_args(texts, **kwargs)
# Convert inputs to features # Convert inputs to features
examples = self._args_parser(*texts, **kwargs)
features = squad_convert_examples_to_features(examples, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False) features = squad_convert_examples_to_features(examples, self.tokenizer, kwargs['max_seq_len'], kwargs['doc_stride'], kwargs['max_question_len'], False)
fw_args = self.inputs_for_model(features) fw_args = self.inputs_for_model(features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment