Unverified Commit 235d34cf authored by Nilesh's avatar Nilesh Committed by GitHub
Browse files

Check for latents, before calling prepare_latents - sdxlImg2Img (#7582)

* Check for latents, before calling prepare_latents - sdxlImg2Img

* Added latents check for all the img2img pipeline

* Fixed silly mistake while checking latents as None
parent 50296739
...@@ -359,8 +359,15 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin): ...@@ -359,8 +359,15 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
# Preprocess image # Preprocess image
image = preprocess(image, width, height) image = preprocess(image, width, height)
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, self.device, generator image,
latent_timestep,
batch_size,
num_images_per_prompt,
text_embeddings.dtype,
self.device,
generator,
) )
if clip_guidance_scale > 0: if clip_guidance_scale > 0:
......
...@@ -335,6 +335,7 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline): ...@@ -335,6 +335,7 @@ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
# 5. Prepare latent variable # 5. Prepare latent variable
num_channels_latents = self.unet.config.in_channels num_channels_latents = self.unet.config.in_channels
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
image, image,
latent_timestep, latent_timestep,
......
...@@ -802,6 +802,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, StableDiffusio ...@@ -802,6 +802,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, StableDiffusio
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables # 6. Prepare latent variables
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
image, image,
latent_timestep, latent_timestep,
......
...@@ -907,6 +907,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline, StableD ...@@ -907,6 +907,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline, StableD
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables # 6. Prepare latent variables
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
image, image,
latent_timestep, latent_timestep,
......
...@@ -1169,6 +1169,7 @@ class StableDiffusionControlNetImg2ImgPipeline( ...@@ -1169,6 +1169,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
self._num_timesteps = len(timesteps) self._num_timesteps = len(timesteps)
# 6. Prepare latent variables # 6. Prepare latent variables
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
image, image,
latent_timestep, latent_timestep,
......
...@@ -1429,6 +1429,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1429,6 +1429,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self._num_timesteps = len(timesteps) self._num_timesteps = len(timesteps)
# 6. Prepare latent variables # 6. Prepare latent variables
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
image, image,
latent_timestep, latent_timestep,
......
...@@ -872,6 +872,7 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -872,6 +872,7 @@ class LatentConsistencyModelImg2ImgPipeline(
else self.scheduler.config.original_inference_steps else self.scheduler.config.original_inference_steps
) )
latent_timestep = timesteps[:1] latent_timestep = timesteps[:1]
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
) )
......
...@@ -239,7 +239,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline): ...@@ -239,7 +239,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
num_embeddings = self.prior.config.num_embeddings num_embeddings = self.prior.config.num_embeddings
embedding_dim = self.prior.config.embedding_dim embedding_dim = self.prior.config.embedding_dim
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
(batch_size, num_embeddings * embedding_dim), (batch_size, num_embeddings * embedding_dim),
image_embeds.dtype, image_embeds.dtype,
......
...@@ -786,6 +786,7 @@ class StableUnCLIPImg2ImgPipeline( ...@@ -786,6 +786,7 @@ class StableUnCLIPImg2ImgPipeline(
# 6. Prepare latent variables # 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels num_channels_latents = self.unet.config.in_channels
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size=batch_size, batch_size=batch_size,
num_channels_latents=num_channels_latents, num_channels_latents=num_channels_latents,
......
...@@ -1247,7 +1247,9 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -1247,7 +1247,9 @@ class StableDiffusionXLImg2ImgPipeline(
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
add_noise = True if self.denoising_start is None else False add_noise = True if self.denoising_start is None else False
# 6. Prepare latent variables # 6. Prepare latent variables
if latents is None:
latents = self.prepare_latents( latents = self.prepare_latents(
image, image,
latent_timestep, latent_timestep,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment