Unverified Commit cd77a036 authored by Suraj Patil, committed by GitHub

[CLIPGuidedStableDiffusion] support DDIM scheduler (#1190)

add ddim in clip guided
parent 663f0c19
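
For orientation, a minimal usage sketch of what this change enables (not part of the commit; the model checkpoints and prompt are illustrative):

import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

# CLIP model that provides the guidance signal (any CLIP checkpoint; this one is an example)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")

# load the community pipeline this commit modifies
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="clip_guided_stable_diffusion",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
)

# DDIM is now a supported scheduler for this pipeline
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

image = pipe(
    "a fantasy landscape",
    num_inference_steps=50,
    eta=0.3,  # new argument from this commit; ignored by schedulers without eta
    clip_guidance_scale=100,
).images[0]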
@@ -5,7 +5,14 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    DiffusionPipeline,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from torchvision import transforms
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
@@ -56,7 +63,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
         feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
@@ -123,7 +130,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # predict the noise residual
         noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

-        if isinstance(self.scheduler, PNDMScheduler):
+        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
             alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
             beta_prod_t = 1 - alpha_prod_t
             # compute predicted original sample from predicted noise also called
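The context line cut off at the hunk boundary continues into the standard DDIM "predicted x_0" computation; in the usual notation the unchanged code just below the hunk presumably reads (a sketch, per eq. (12) of https://arxiv.org/abs/2010.02502, not shown in the diff):

# x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_bar_t)
pred_original_sample = (sample - beta_prod_t ** 0.5 * noise_pred) / alpha_prod_t ** 0.5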
@@ -176,6 +183,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
         clip_guidance_scale: Optional[float] = 100,
         clip_prompt: Optional[Union[str, List[str]]] = None,
         num_cutouts: Optional[int] = 4,
@@ -275,6 +283,20 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma

+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
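The block added above follows the feature-detection pattern from the main StableDiffusionPipeline: inspect the scheduler's step() signature and forward only the kwargs it accepts, so eta can be threaded through DDIM without breaking LMS or PNDM. The same idea as a standalone helper (a sketch, not part of the diff; build_step_kwargs is a hypothetical name):

import inspect

def build_step_kwargs(scheduler, eta=0.0, generator=None):
    # forward only the kwargs this scheduler's step() actually accepts,
    # e.g. DDIMScheduler.step takes eta while LMSDiscreteScheduler.step does not
    params = set(inspect.signature(scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if "eta" in params:
        extra_step_kwargs["eta"] = eta
    if "generator" in params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs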
@@ -306,7 +328,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 )

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents