Unverified Commit 6d1a6486 authored by alirezafarashah's avatar alirezafarashah Committed by GitHub
Browse files

Fix small inconsistency in output dimension of "_get_t5_prompt_embeds"...


Fix small inconsistency in output dimension of "_get_t5_prompt_embeds" function in sd3 pipeline (#12531)

* Fix small inconsistency in output dimension of t5 embeds when text_encoder_3 is None

* first commit

---------
Co-authored-by: default avatarAlireza Farashah <alireza.farashah@cn-g017.server.mila.quebec>
Co-authored-by: default avatarAlireza Farashah <alireza.farashah@login-2.server.mila.quebec>
parent 250f5cb5
......@@ -266,7 +266,7 @@ class StableDiffusion3ControlNetPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
......
......@@ -284,7 +284,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
......
......@@ -237,7 +237,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
......
......@@ -253,7 +253,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
......
......@@ -248,7 +248,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
......
......@@ -272,7 +272,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
......
......@@ -278,7 +278,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment