Commit 8e3b1c86 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Added FeatureExtraction pipeline.

parent f1971bf3
...@@ -143,7 +143,23 @@ class JsonPipelineDataFormat(PipelineDataFormat): ...@@ -143,7 +143,23 @@ class JsonPipelineDataFormat(PipelineDataFormat):
class FeatureExtractionPipeline(Pipeline): class FeatureExtractionPipeline(Pipeline):
def __call__(self, *texts, **kwargs): def __call__(self, *texts, **kwargs):
pass # Generic compatibility with sklearn and Keras
if 'X' in kwargs and not texts:
texts = kwargs.pop('X')
inputs = self.tokenizer.batch_encode_plus(
texts, add_special_tokens=True, return_tensors='tf' if is_tf_available() else 'pt'
)
if is_tf_available():
# TODO trace model
predictions = self.model(inputs)[0]
else:
import torch
with torch.no_grad():
predictions = self.model(**inputs)[0]
return predictions.numpy().tolist()
class TextClassificationPipeline(Pipeline): class TextClassificationPipeline(Pipeline):
...@@ -424,6 +440,11 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -424,6 +440,11 @@ class QuestionAnsweringPipeline(Pipeline):
# Register all the supported task here # Register all the supported task here
SUPPORTED_TASKS = { SUPPORTED_TASKS = {
'feature-extraction': {
'impl': FeatureExtractionPipeline,
'tf': TFAutoModel if is_tf_available() else None,
'pt': AutoModel if is_torch_available() else None,
},
'text-classification': { 'text-classification': {
'impl': TextClassificationPipeline, 'impl': TextClassificationPipeline,
'tf': TFAutoModelForSequenceClassification if is_tf_available() else None, 'tf': TFAutoModelForSequenceClassification if is_tf_available() else None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment