Unverified Commit a3f9221a authored by regisss's avatar regisss Committed by GitHub
Browse files

Add generate kwargs to VQA pipeline (#29134)

parent 871ba71d
......@@ -123,9 +123,9 @@ class VisualQuestionAnsweringPipeline(Pipeline):
model_inputs.update(image_features)
return model_inputs
def _forward(self, model_inputs):
def _forward(self, model_inputs, **generate_kwargs):
if self.model.can_generate():
model_outputs = self.model.generate(**model_inputs)
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
else:
model_outputs = self.model(**model_inputs)
return model_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment