Unverified commit fca66ec4, authored by Nicolas Patry and committed by GitHub

Fixing a hard-to-trigger bug in the `text-generation` pipeline. (#18131)

* Fixing a bug where the attention mask was not passed to `generate`.

* Fixing zero-size prompts.

* Moving the `# BS x SL` comment above the `generate` call.
parent 8581a798
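
For context, the root cause is that `generate` cannot tell which positions of a padded batch are real tokens unless the tokenizer's `attention_mask` is forwarded to it. Below is a minimal sketch of the same call at the model level; the "gpt2" checkpoint, the prompts, and the generation settings are illustrative assumptions, not taken from this PR.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative sketch only; "gpt2", the prompts, and max_new_tokens are assumptions.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# GPT-2 defines no pad token, so reuse EOS and pad on the left for generation.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Batch two prompts of different lengths; the shorter one is padded.
enc = tokenizer(["Hi", "The quick brown fox jumps over the lazy"], return_tensors="pt", padding=True)

# Forwarding the attention mask keeps padded positions out of attention;
# this is the tensor the patched pipeline now passes through to generate().
out = model.generate(
    input_ids=enc["input_ids"],
    attention_mask=enc["attention_mask"],
    max_new_tokens=5,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))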
@@ -205,14 +205,17 @@ class TextGenerationPipeline(Pipeline):
     def _forward(self, model_inputs, **generate_kwargs):
         input_ids = model_inputs["input_ids"]
+        attention_mask = model_inputs.get("attention_mask", None)
         # Allow empty prompts
         if input_ids.shape[1] == 0:
             input_ids = None
+            attention_mask = None
             in_b = 1
         else:
             in_b = input_ids.shape[0]
         prompt_text = model_inputs.pop("prompt_text")
-        generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL
+        # BS x SL
+        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
         out_b = generated_sequence.shape[0]
         if self.framework == "pt":
             generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
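
A rough usage sketch of the two pipeline scenarios the patch covers; the model name, prompts, and generation settings are assumptions for illustration. Batched prompts of unequal length produce a padded `attention_mask` that must reach `generate`, and an empty prompt yields `input_ids` of shape `(1, 0)`, in which case both tensors are dropped.

from transformers import pipeline

# Illustrative sketch; "gpt2", the prompts, and max_new_tokens are assumptions.
generator = pipeline("text-generation", model="gpt2")
# Batching needs a pad token; GPT-2 does not define one, so reuse EOS.
generator.tokenizer.pad_token_id = generator.tokenizer.eos_token_id

# Prompts of different lengths are padded when batched; the fix forwards the
# resulting attention_mask to generate() instead of silently discarding it.
padded = generator(["Hi", "The quick brown fox jumps over the lazy"], batch_size=2, max_new_tokens=5)

# Empty prompt: input_ids ends up with shape (1, 0), so the pipeline now sets
# both input_ids and attention_mask to None before calling generate().
empty = generator("", max_new_tokens=5)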