Unverified Commit dec759e7 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding support for `truncation` parameter on `feature-extraction` pipeline. (#14193)

* Adding support for `truncation` parameter on `feature-extraction`
pipeline.

Fixes #14183

* Fixing tests on ibert, longformer, and roberta.

* Rebase fix.
parent 27b1516d
......@@ -41,12 +41,19 @@ class FeatureExtractionPipeline(Pipeline):
the associated CUDA device id.
"""
def _sanitize_parameters(self, **kwargs):
return {}, {}, {}
def _sanitize_parameters(self, truncation=None, **kwargs):
preprocess_params = {}
if truncation is not None:
preprocess_params["truncation"] = truncation
return preprocess_params, {}, {}
def preprocess(self, inputs) -> Dict[str, GenericTensor]:
def preprocess(self, inputs, truncation=None) -> Dict[str, GenericTensor]:
return_tensors = self.framework
model_inputs = self.tokenizer(inputs, return_tensors=return_tensors)
if truncation is None:
kwargs = {}
else:
kwargs = {"truncation": truncation}
model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **kwargs)
return model_inputs
def _forward(self, model_inputs):
......
......@@ -22,7 +22,15 @@ from abc import abstractmethod
from functools import lru_cache
from unittest import skipIf
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer, pipeline
from transformers import (
FEATURE_EXTRACTOR_MAPPING,
TOKENIZER_MAPPING,
AutoFeatureExtractor,
AutoTokenizer,
IBertConfig,
RobertaConfig,
pipeline,
)
from transformers.pipelines.base import _pad
from transformers.testing_utils import is_pipeline_test, require_torch
......@@ -143,7 +151,7 @@ class PipelineTestCaseMeta(type):
try:
tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
# XLNet actually defines it as -1.
if model.config.__class__.__name__ == "RobertaConfig":
if isinstance(model.config, (RobertaConfig, IBertConfig)):
tokenizer.model_max_length = model.config.max_position_embeddings - 2
elif (
hasattr(model.config, "max_position_embeddings")
......
......@@ -105,3 +105,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
outputs = feature_extractor(["This is a test", "Another longer test"])
shape = self.get_shape(outputs)
self.assertEqual(shape[0], 2)
outputs = feature_extractor("This is a test" * 100, truncation=True)
shape = self.get_shape(outputs)
self.assertEqual(shape[0], 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment