"vscode:/vscode.git/clone" did not exist on "480510ada99a8fd7cae8de47bb202382250d6873"
Unverified Commit 2c25b98c authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[AuraFlow] fix long prompt handling (#8937)

fix
parent 93983b67
......@@ -260,7 +260,6 @@ class AuraFlowPipeline(DiffusionPipeline):
padding="max_length",
return_tensors="pt",
)
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
text_input_ids = text_inputs["input_ids"]
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
......@@ -273,6 +272,7 @@ class AuraFlowPipeline(DiffusionPipeline):
f" {max_length} tokens: {removed_text}"
)
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
prompt_embeds = self.text_encoder(**text_inputs)[0]
prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
prompt_embeds = prompt_embeds * prompt_attention_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment