Unverified Commit 80ff4ba6 authored by Dhruv Nair, committed by GitHub

Fix issue with prompt embeds and latents in SD Cascade Decoder with multiple image embeddings for a single prompt. (#7381)

* fix

* update

* update

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent b09a2aa3
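
For context, a minimal end-to-end sketch of the scenario this commit fixes: the prior producing more than one image embedding for a single prompt, which the decoder then has to expand its prompt embeds and latents to match. Checkpoint ids, dtypes, and the num_images_per_prompt values below are illustrative, taken from the public Stable Cascade examples, not from this diff.

    import torch
    from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

    # Illustrative checkpoints from the public Stable Cascade release.
    prior = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
    ).to("cuda")
    decoder = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", torch_dtype=torch.float16
    ).to("cuda")

    prompt = "an astronaut riding a horse"

    # The prior returns two image embeddings for the single prompt.
    prior_output = prior(prompt=prompt, num_images_per_prompt=2)

    # The decoder now detects that image_embeddings carries more rows than the
    # prompt batch and repeats the prompt embeds and latents accordingly.
    images = decoder(
        image_embeddings=prior_output.image_embeddings.to(torch.float16),
        prompt=prompt,
        guidance_scale=0.0,
        output_type="pil",
    ).images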
@@ -100,8 +100,10 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
         )
         self.register_to_config(latent_dim_scale=latent_dim_scale)
 
-    def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
-        batch_size, channels, height, width = image_embeddings.shape
+    def prepare_latents(
+        self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler
+    ):
+        _, channels, height, width = image_embeddings.shape
         latents_shape = (
             batch_size * num_images_per_prompt,
             4,
@@ -383,7 +385,19 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
         )
         if isinstance(image_embeddings, list):
             image_embeddings = torch.cat(image_embeddings, dim=0)
-        batch_size = image_embeddings.shape[0]
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Compute the effective number of images per prompt
+        # We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1
+        # This results in a case where a single prompt is associated with multiple image embeddings
+        # Divide the number of image embeddings by the batch size to determine if this is the case.
+        num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)
 
         # 2. Encode caption
         if prompt_embeds is None and negative_prompt_embeds is None:
@@ -417,7 +431,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
 
         # 5. Prepare latents
         latents = self.prepare_latents(
-            image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
+            batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
         )
 
         # 6. Run denoising loop
......
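
To illustrate the batching arithmetic introduced in the pipeline changes above, a small worked sketch (the variable names below are hypothetical; only the commented formula mirrors the committed code):

    # One prompt, but the prior was run with num_images_per_prompt=2, so
    # image_embeddings.shape[0] == 2 while the prompt batch size is 1.
    batch_size = 1
    prior_images_per_prompt = 2
    decoder_images_per_prompt = 2

    # Mirrors: num_images_per_prompt *= image_embeddings.shape[0] // batch_size
    effective_images_per_prompt = decoder_images_per_prompt * (
        prior_images_per_prompt * batch_size // batch_size
    )
    assert effective_images_per_prompt == 4

    # prepare_latents then allocates batch_size * num_images_per_prompt latents,
    # so four images are generated in total, matching the new tests below.
    assert batch_size * effective_images_per_prompt == 4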
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
     slow,
     torch_device,
 )
+from diffusers.utils.torch_utils import randn_tensor
 
 from ..test_pipelines_common import PipelineTesterMixin
@@ -246,6 +247,66 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
 
+    def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = StableCascadeDecoderPipeline(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        prior_num_images_per_prompt = 2
+        decoder_num_images_per_prompt = 2
+        prompt = ["a cat"]
+        batch_size = len(prompt)
+
+        generator = torch.Generator(device)
+        image_embeddings = randn_tensor(
+            (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
+        )
+        decoder_output = pipe(
+            image_embeddings=image_embeddings,
+            prompt=prompt,
+            num_inference_steps=1,
+            output_type="np",
+            guidance_scale=0.0,
+            generator=generator.manual_seed(0),
+            num_images_per_prompt=decoder_num_images_per_prompt,
+        )
+
+        assert decoder_output.images.shape[0] == (
+            batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
+        )
+
+    def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings_with_guidance(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = StableCascadeDecoderPipeline(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        prior_num_images_per_prompt = 2
+        decoder_num_images_per_prompt = 2
+        prompt = ["a cat"]
+        batch_size = len(prompt)
+
+        generator = torch.Generator(device)
+        image_embeddings = randn_tensor(
+            (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
+        )
+        decoder_output = pipe(
+            image_embeddings=image_embeddings,
+            prompt=prompt,
+            num_inference_steps=1,
+            output_type="np",
+            guidance_scale=2.0,
+            generator=generator.manual_seed(0),
+            num_images_per_prompt=decoder_num_images_per_prompt,
+        )
+
+        assert decoder_output.images.shape[0] == (
+            batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
+        )
+
 @slow
 @require_torch_gpu
......