Unverified Commit 8252e24a authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`Generate`] Add conditional generation for multimodal models (#22424)

* add conditional generation

* add comments
parent 33f4cb10
...@@ -1288,6 +1288,10 @@ class GenerationMixin: ...@@ -1288,6 +1288,10 @@ class GenerationMixin:
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
device=inputs_tensor.device, device=inputs_tensor.device,
) )
# conditional generation for multi-modal models.
if "input_ids" in model_kwargs and model_input_name == "pixel_values":
input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
else: else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment