Unverified commit d3f4cef7, authored by Rak Alexey, committed by GitHub
Browse files

fix image2test args forwarding (#19648)



* fix image2test args forwarding

* fix issues

* Proposing the update to the PR.

* Fixup.
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 3b419cfc
...@@ -44,8 +44,20 @@ class ImageToTextPipeline(Pipeline): ...@@ -44,8 +44,20 @@ class ImageToTextPipeline(Pipeline):
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
) )
def _sanitize_parameters(self, **kwargs): def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None):
return {}, {}, {} forward_kwargs = {}
if generate_kwargs is not None:
forward_kwargs["generate_kwargs"] = generate_kwargs
if max_new_tokens is not None:
if "generate_kwargs" not in forward_kwargs:
forward_kwargs["generate_kwargs"] = {}
if "max_new_tokens" in forward_kwargs["generate_kwargs"]:
raise ValueError(
"'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter,"
" please use only one"
)
forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
return {}, forward_kwargs, {}
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs): def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
""" """
...@@ -61,6 +73,12 @@ class ImageToTextPipeline(Pipeline): ...@@ -61,6 +73,12 @@ class ImageToTextPipeline(Pipeline):
The pipeline accepts either a single image or a batch of images. The pipeline accepts either a single image or a batch of images.
max_new_tokens (`int`, *optional*):
The maximum number of tokens to generate. By default, the `generate` default is used.
generate_kwargs (`Dict`, *optional*):
Extra keyword arguments forwarded directly to `generate`, allowing full control of that function.
Return: Return:
A list or a list of list of `dict`: Each result comes as a dictionary with the following key: A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
...@@ -73,13 +91,15 @@ class ImageToTextPipeline(Pipeline): ...@@ -73,13 +91,15 @@ class ImageToTextPipeline(Pipeline):
model_inputs = self.feature_extractor(images=image, return_tensors=self.framework) model_inputs = self.feature_extractor(images=image, return_tensors=self.framework)
return model_inputs return model_inputs
def _forward(self, model_inputs): def _forward(self, model_inputs, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}
# FIXME: We need to pop here due to a difference in how `generation_utils.py` and `generation_tf_utils.py` # FIXME: We need to pop here due to a difference in how `generation_utils.py` and `generation_tf_utils.py`
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name` # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
# in the `_prepare_model_inputs` method. # in the `_prepare_model_inputs` method.
inputs = model_inputs.pop(self.model.main_input_name) inputs = model_inputs.pop(self.model.main_input_name)
model_outputs = self.model.generate(inputs, **model_inputs) model_outputs = self.model.generate(inputs, **model_inputs, **generate_kwargs)
return model_outputs return model_outputs
def postprocess(self, model_outputs): def postprocess(self, model_outputs):
......
...@@ -86,6 +86,12 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta ...@@ -86,6 +86,12 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta
], ],
) )
outputs = pipe(image, max_new_tokens=1)
self.assertEqual(
outputs,
[{"generated_text": "growth"}],
)
@require_torch @require_torch
def test_small_model_pt(self): def test_small_model_pt(self):
pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2") pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment