"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "20e92586c1fda968ea3343ba0f44f2b21f3c09d2"
Unverified Commit 2921a201 authored by Nan's avatar Nan Committed by GitHub
Browse files

[SD3] Fix mis-matched shape when num_images_per_prompt > 1 using without T5...


[SD3] Fix mis-matched shape when num_images_per_prompt > 1 using without T5 (text_encoder_3=None) (#8558)

* fix shape mismatch when num_images_per_prompt > 1 and text_encoder_3=None

* style

* fix copies

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent 3376252d
...@@ -217,7 +217,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -217,7 +217,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
if self.text_encoder_3 is None: if self.text_encoder_3 is None:
return torch.zeros( return torch.zeros(
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim), (
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
......
...@@ -232,7 +232,11 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): ...@@ -232,7 +232,11 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
if self.text_encoder_3 is None: if self.text_encoder_3 is None:
return torch.zeros( return torch.zeros(
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim), (
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment