"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c50d997591d14dfa2030b015d2a5934add658b1d"
Unverified Commit fa4c0e5e authored by C's avatar C Committed by GitHub
Browse files

optimize QwenImagePipeline to reduce unnecessary CUDA synchronization (#12072)

parent b793debd
...@@ -636,6 +636,11 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): ...@@ -636,6 +636,11 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None: if self.attention_kwargs is None:
self._attention_kwargs = {} self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop # 6. Denoising loop
self.scheduler.set_begin_index(0) self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
...@@ -654,7 +659,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): ...@@ -654,7 +659,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes, img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs, attention_kwargs=self.attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
...@@ -668,7 +673,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): ...@@ -668,7 +673,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds, encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes, img_shapes=img_shapes,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs, attention_kwargs=self.attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment