Unverified commit fa4c0e5e, authored by C and committed by GitHub
Browse files

optimize QwenImagePipeline to reduce unnecessary CUDA synchronization (#12072)

parent b793debd
......@@ -636,6 +636,11 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
......@@ -654,7 +659,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
......@@ -668,7 +673,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment