"vscode:/vscode.git/clone" did not exist on "791650ddef9eb11e011506dbd5d22ed6bfcb6a10"
Unverified Commit 2921a201 authored by Nan's avatar Nan Committed by GitHub
Browse files

[SD3] Fix mis-matched shape when num_images_per_prompt > 1 using without T5...


[SD3] Fix mis-matched shape when num_images_per_prompt > 1 using without T5 (text_encoder_3=None) (#8558)

* fix shape mismatch when num_images_per_prompt > 1 and text_encoder_3=None

* style

* fix copies

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent 3376252d
......@@ -217,7 +217,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
if self.text_encoder_3 is None:
return torch.zeros(
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device,
dtype=dtype,
)
......
......@@ -232,7 +232,11 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
if self.text_encoder_3 is None:
return torch.zeros(
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
device=device,
dtype=dtype,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment