Unverified Commit 235d34cf authored by Nilesh, committed by GitHub

Check for latents before calling prepare_latents - sdxlImg2Img (#7582)

* Check for latents before calling prepare_latents - sdxlImg2Img

* Added latents check for all the img2img pipelines

* Fixed silly mistake while checking whether latents is None
parent 50296739
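The change itself is mechanical: each img2img `__call__` now only builds latents from the input image when the caller did not supply them. A minimal usage sketch of what this enables (not part of this commit; the model id, resolution, and randomly initialized latents below are purely illustrative):

```python
import torch
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Placeholder init image; in real use this would be the image to transform.
init_image = Image.new("RGB", (1024, 1024))

# Precomputed latents, e.g. carried over from an earlier run. The shape
# (batch, latent channels, height // 8, width // 8) is an assumption for SDXL's VAE.
latents = torch.randn(1, 4, 128, 128, dtype=torch.float16, device="cuda")

# With this patch, prepare_latents is only called when `latents is None`,
# so the tensor above is used as-is instead of being rebuilt from `init_image`.
image = pipe(
    prompt="a photo of an astronaut",
    image=init_image,
    latents=latents,
).images[0]
```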
@@ -359,9 +359,16 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
         # Preprocess image
         image = preprocess(image, width, height)
-        latents = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, self.device, generator
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                text_embeddings.dtype,
+                self.device,
+                generator,
+            )

         if clip_guidance_scale > 0:
             if clip_prompt is not None:
@@ -335,17 +335,18 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
         # 5. Prepare latent variable
         num_channels_latents = self.unet.config.in_channels
-        latents = self.prepare_latents(
-            image,
-            latent_timestep,
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            latents,
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                prompt_embeds.dtype,
+                device,
+                latents,
+            )
         bs = batch_size * num_images_per_prompt

         # 6. Get Guidance Scale Embedding
@@ -802,15 +802,16 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, StableDiffusio
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 6. Prepare latent variables
-        latents = self.prepare_latents(
-            image,
-            latent_timestep,
-            batch_size,
-            num_images_per_prompt,
-            prompt_embeds.dtype,
-            device,
-            generator,
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                prompt_embeds.dtype,
+                device,
+                generator,
+            )

         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -907,15 +907,16 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline, StableD
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

         # 6. Prepare latent variables
-        latents = self.prepare_latents(
-            image,
-            latent_timestep,
-            batch_size,
-            num_images_per_prompt,
-            prompt_embeds.dtype,
-            device,
-            generator,
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                prompt_embeds.dtype,
+                device,
+                generator,
+            )

         mask_image_latents = self.prepare_mask_latents(
             mask_image,
@@ -1169,15 +1169,16 @@ class StableDiffusionControlNetImg2ImgPipeline(
         self._num_timesteps = len(timesteps)

         # 6. Prepare latent variables
-        latents = self.prepare_latents(
-            image,
-            latent_timestep,
-            batch_size,
-            num_images_per_prompt,
-            prompt_embeds.dtype,
-            device,
-            generator,
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                prompt_embeds.dtype,
+                device,
+                generator,
+            )

         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -1429,16 +1429,17 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
         self._num_timesteps = len(timesteps)

         # 6. Prepare latent variables
-        latents = self.prepare_latents(
-            image,
-            latent_timestep,
-            batch_size,
-            num_images_per_prompt,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            True,
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                True,
+            )

         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -872,9 +872,10 @@ class LatentConsistencyModelImg2ImgPipeline(
             else self.scheduler.config.original_inference_steps
         )
         latent_timestep = timesteps[:1]
-        latents = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+            )
         bs = batch_size * num_images_per_prompt

         # 6. Get Guidance Scale Embedding
@@ -239,15 +239,15 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
         num_embeddings = self.prior.config.num_embeddings
         embedding_dim = self.prior.config.embedding_dim
-
-        latents = self.prepare_latents(
-            (batch_size, num_embeddings * embedding_dim),
-            image_embeds.dtype,
-            device,
-            generator,
-            latents,
-            self.scheduler,
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                (batch_size, num_embeddings * embedding_dim),
+                image_embeds.dtype,
+                device,
+                generator,
+                latents,
+                self.scheduler,
+            )

         # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
         latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)
@@ -786,16 +786,17 @@ class StableUnCLIPImg2ImgPipeline(
         # 6. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
-        latents = self.prepare_latents(
-            batch_size=batch_size,
-            num_channels_latents=num_channels_latents,
-            height=height,
-            width=width,
-            dtype=prompt_embeds.dtype,
-            device=device,
-            generator=generator,
-            latents=latents,
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                batch_size=batch_size,
+                num_channels_latents=num_channels_latents,
+                height=height,
+                width=width,
+                dtype=prompt_embeds.dtype,
+                device=device,
+                generator=generator,
+                latents=latents,
+            )

         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -1247,17 +1247,19 @@ class StableDiffusionXLImg2ImgPipeline(
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         add_noise = True if self.denoising_start is None else False
         # 6. Prepare latent variables
-        latents = self.prepare_latents(
-            image,
-            latent_timestep,
-            batch_size,
-            num_images_per_prompt,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            add_noise,
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                add_noise,
+            )

         # 7. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
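For reference, the step the new guard skips when latents are supplied: in the img2img pipelines, prepare_latents roughly VAE-encodes the init image and noises it to the starting timestep. A simplified paraphrase, not the library's exact code; the helper name and signature here are made up for illustration:

```python
import torch


def img2img_prepare_latents_sketch(vae, scheduler, image, latent_timestep, add_noise=True):
    # Rough paraphrase of what the guarded call does when `latents` is None:
    # encode the init image into VAE latent space, scale it, then noise it to
    # the starting timestep so denoising can resume from there.
    init_latents = vae.encode(image).latent_dist.sample()
    init_latents = vae.config.scaling_factor * init_latents
    if add_noise:
        noise = torch.randn_like(init_latents)
        init_latents = scheduler.add_noise(init_latents, noise, latent_timestep)
    return init_latents
```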