"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "00d8502b7ade5aa3da43b13f23bb447faa6d459e"
Unverified commit ee2a80ec authored by Nicolas Patry, committed by GitHub
Browse files

add return_tensors parameter for feature_extraction 2 (#19707)

* add return_tensors parameter for feature_extraction  w/ test

add return_tensor parameter for feature extraction

Revert "Merge branch 'feature-extraction-return-tensor' of https://github.com/ajsanjoaquin/transformers into feature-extraction-return-tensor"

This reverts commit d559da743b87914e111a84a98ba6dbb70d08ad88, reversing
changes made to bbef89278650c04c090beb65637a8e9572dba222.

call parameter directly
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

Fixup.

Update src/transformers/pipelines/feature_extraction.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix the imports.

* Fixing the test by not overflowing the model capacity.
Co-authored-by: AJ San Joaquin <ajsanjoaquin@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 02b63702
...@@ -31,6 +31,8 @@ class FeatureExtractionPipeline(Pipeline): ...@@ -31,6 +31,8 @@ class FeatureExtractionPipeline(Pipeline):
If no framework is specified, will default to the one currently installed. If no framework is specified and If no framework is specified, will default to the one currently installed. If no framework is specified and
both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
provided. provided.
return_tensors (`bool`, *optional*):
If `True`, returns a tensor according to the specified framework, otherwise returns a list.
task (`str`, defaults to `""`): task (`str`, defaults to `""`):
A task-identifier for the pipeline. A task-identifier for the pipeline.
args_parser ([`~pipelines.ArgumentHandler`], *optional*): args_parser ([`~pipelines.ArgumentHandler`], *optional*):
...@@ -40,7 +42,7 @@ class FeatureExtractionPipeline(Pipeline): ...@@ -40,7 +42,7 @@ class FeatureExtractionPipeline(Pipeline):
the associated CUDA device id. the associated CUDA device id.
""" """
def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, **kwargs): def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
if tokenize_kwargs is None: if tokenize_kwargs is None:
tokenize_kwargs = {} tokenize_kwargs = {}
...@@ -53,7 +55,11 @@ class FeatureExtractionPipeline(Pipeline): ...@@ -53,7 +55,11 @@ class FeatureExtractionPipeline(Pipeline):
preprocess_params = tokenize_kwargs preprocess_params = tokenize_kwargs
return preprocess_params, {}, {} postprocess_params = {}
if return_tensors is not None:
postprocess_params["return_tensors"] = return_tensors
return preprocess_params, {}, postprocess_params
def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]: def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
return_tensors = self.framework return_tensors = self.framework
...@@ -64,8 +70,10 @@ class FeatureExtractionPipeline(Pipeline): ...@@ -64,8 +70,10 @@ class FeatureExtractionPipeline(Pipeline):
model_outputs = self.model(**model_inputs) model_outputs = self.model(**model_inputs)
return model_outputs return model_outputs
def postprocess(self, model_outputs): def postprocess(self, model_outputs, return_tensors=False):
# [0] is the first available tensor, logits or last_hidden_state. # [0] is the first available tensor, logits or last_hidden_state.
if return_tensors:
return model_outputs[0]
if self.framework == "pt": if self.framework == "pt":
return model_outputs[0].tolist() return model_outputs[0].tolist()
elif self.framework == "tf": elif self.framework == "tf":
......
...@@ -22,6 +22,8 @@ from transformers import ( ...@@ -22,6 +22,8 @@ from transformers import (
TF_MODEL_MAPPING, TF_MODEL_MAPPING,
FeatureExtractionPipeline, FeatureExtractionPipeline,
LxmertConfig, LxmertConfig,
is_tf_available,
is_torch_available,
pipeline, pipeline,
) )
from transformers.testing_utils import nested_simplify, require_tf, require_torch from transformers.testing_utils import nested_simplify, require_tf, require_torch
...@@ -29,6 +31,13 @@ from transformers.testing_utils import nested_simplify, require_tf, require_torc ...@@ -29,6 +31,13 @@ from transformers.testing_utils import nested_simplify, require_tf, require_torc
from .test_pipelines_common import PipelineTestCaseMeta from .test_pipelines_common import PipelineTestCaseMeta
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_MAPPING model_mapping = MODEL_MAPPING
tf_model_mapping = TF_MODEL_MAPPING tf_model_mapping = TF_MODEL_MAPPING
...@@ -133,6 +142,22 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa ...@@ -133,6 +142,22 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
tokenize_kwargs=tokenize_kwargs, tokenize_kwargs=tokenize_kwargs,
) )
@require_torch
def test_return_tensors_pt(self):
feature_extractor = pipeline(
task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
)
outputs = feature_extractor("This is a test", return_tensors=True)
self.assertTrue(torch.is_tensor(outputs))
@require_tf
def test_return_tensors_tf(self):
feature_extractor = pipeline(
task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="tf"
)
outputs = feature_extractor("This is a test", return_tensors=True)
self.assertTrue(tf.is_tensor(outputs))
def get_shape(self, input_, shape=None): def get_shape(self, input_, shape=None):
if shape is None: if shape is None:
shape = [] shape = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment