Unverified Commit ee2a80ec authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

add return_tensors parameter for feature_extraction 2 (#19707)

* add return_tensors parameter for feature_extraction  w/ test

add return_tensor parameter for feature extraction

Revert "Merge branch 'feature-extraction-return-tensor' of https://github.com/ajsanjoaquin/transformers

 into feature-extraction-return-tensor"

This reverts commit d559da743b87914e111a84a98ba6dbb70d08ad88, reversing
changes made to bbef89278650c04c090beb65637a8e9572dba222.

call parameter directly
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>

Fixup.

Update src/transformers/pipelines/feature_extraction.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix the imports.

* Fixing the test by not overflowing the model capacity.
Co-authored-by: default avatarAJ San Joaquin <ajsanjoaquin@gmail.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 02b63702
......@@ -31,6 +31,8 @@ class FeatureExtractionPipeline(Pipeline):
If no framework is specified, will default to the one currently installed. If no framework is specified and
both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
provided.
return_tensor (`bool`, *optional*):
If `True`, returns a tensor according to the specified framework, otherwise returns a list.
task (`str`, defaults to `""`):
A task-identifier for the pipeline.
args_parser ([`~pipelines.ArgumentHandler`], *optional*):
......@@ -40,7 +42,7 @@ class FeatureExtractionPipeline(Pipeline):
the associated CUDA device id.
"""
def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, **kwargs):
def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
if tokenize_kwargs is None:
tokenize_kwargs = {}
......@@ -53,7 +55,11 @@ class FeatureExtractionPipeline(Pipeline):
preprocess_params = tokenize_kwargs
return preprocess_params, {}, {}
postprocess_params = {}
if return_tensors is not None:
postprocess_params["return_tensors"] = return_tensors
return preprocess_params, {}, postprocess_params
def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
return_tensors = self.framework
......@@ -64,8 +70,10 @@ class FeatureExtractionPipeline(Pipeline):
model_outputs = self.model(**model_inputs)
return model_outputs
def postprocess(self, model_outputs):
def postprocess(self, model_outputs, return_tensors=False):
# [0] is the first available tensor, logits or last_hidden_state.
if return_tensors:
return model_outputs[0]
if self.framework == "pt":
return model_outputs[0].tolist()
elif self.framework == "tf":
......
......@@ -22,6 +22,8 @@ from transformers import (
TF_MODEL_MAPPING,
FeatureExtractionPipeline,
LxmertConfig,
is_tf_available,
is_torch_available,
pipeline,
)
from transformers.testing_utils import nested_simplify, require_tf, require_torch
......@@ -29,6 +31,13 @@ from transformers.testing_utils import nested_simplify, require_tf, require_torc
from .test_pipelines_common import PipelineTestCaseMeta
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_MAPPING
tf_model_mapping = TF_MODEL_MAPPING
......@@ -133,6 +142,22 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
tokenize_kwargs=tokenize_kwargs,
)
@require_torch
def test_return_tensors_pt(self):
feature_extractor = pipeline(
task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
)
outputs = feature_extractor("This is a test", return_tensors=True)
self.assertTrue(torch.is_tensor(outputs))
@require_tf
def test_return_tensors_tf(self):
feature_extractor = pipeline(
task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="tf"
)
outputs = feature_extractor("This is a test", return_tensors=True)
self.assertTrue(tf.is_tensor(outputs))
def get_shape(self, input_, shape=None):
if shape is None:
shape = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment