Unverified Commit efa773af authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Support K-LMS in img2img (#270)

* Support K-LMS in img2img

* Apply review suggestions
parent da7d4cf2
...@@ -5,7 +5,14 @@ import numpy as np ...@@ -5,7 +5,14 @@ import numpy as np
import torch import torch
import PIL import PIL
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
...@@ -87,12 +94,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -87,12 +94,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# get the original timestep using init_timestep # get the original timestep using init_timestep
init_timestep = int(num_inference_steps * strength) + offset init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps) init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep] if isinstance(self.scheduler, LMSDiscreteScheduler):
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) timesteps = torch.tensor(
[num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
)
else:
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
# add noise to latents using the timesteps # add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device) noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)
# get prompt text embeddings # get prompt text embeddings
text_input = self.tokenizer( text_input = self.tokenizer(
...@@ -133,8 +145,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -133,8 +145,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
latents = init_latents latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0) t_start = max(num_inference_steps - init_timestep + offset, 0)
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
t_index = t_start + i
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[t_index]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
latent_model_input = latent_model_input.to(self.unet.dtype)
t = t.to(self.unet.dtype)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
...@@ -145,11 +164,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -145,11 +164,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
# scale and decode the image latents with vae # scale and decode the image latents with vae
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae.decode(latents) image = self.vae.decode(latents.to(self.vae.dtype))
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
......
...@@ -138,6 +138,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -138,6 +138,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler): if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[i] sigma = self.scheduler.sigmas[i]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual # predict the noise residual
......
...@@ -124,8 +124,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -124,8 +124,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return {"prev_sample": prev_sample} return {"prev_sample": prev_sample}
def add_noise(self, original_samples, noise, timesteps):
    """Diffuse ``original_samples`` to the noise level given by ``timesteps``.

    Adds ``noise`` scaled by the per-timestep sigma, so each batch element
    can be noised to its own level.

    Args:
        original_samples: clean samples/latents to be noised.
        noise: noise tensor with the same shape as ``original_samples``.
        timesteps: integer index (or per-batch indices) into ``self.sigmas``.

    Returns:
        ``original_samples + noise * sigma`` where ``sigma`` is
        ``self.sigmas[timesteps]`` broadcast to the sample shape.
    """
    # Index sigmas by timestep FIRST, then broadcast to the noise shape.
    # (Broadcasting the full sigma schedule and indexing afterwards would
    # select along the wrong, already-reshaped axis.)
    # NOTE(review): match_shape presumably reshapes `sigmas` so it
    # broadcasts against `noise` — confirm against the scheduler mixin.
    sigmas = self.match_shape(self.sigmas[timesteps], noise)
    noisy_samples = original_samples + noise * sigmas
    return noisy_samples
def __len__(self): def __len__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment