Unverified commit 055c90f5, authored by Nipun Jindal, committed by GitHub
Browse files

[2737]: Add DPMSolverMultistepScheduler to CLIP guided community pipeline (#2779)



[2737]: Add DPMSolverMultistepScheduler to CLIP guided community pipelines
Co-authored-by: njindal <njindal@adobe.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2ef9bdd7
...@@ -11,6 +11,7 @@ from diffusers import ( ...@@ -11,6 +11,7 @@ from diffusers import (
AutoencoderKL, AutoencoderKL,
DDIMScheduler, DDIMScheduler,
DiffusionPipeline, DiffusionPipeline,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
UNet2DConditionModel, UNet2DConditionModel,
...@@ -63,7 +64,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): ...@@ -63,7 +64,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
clip_model: CLIPModel, clip_model: CLIPModel,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
feature_extractor: CLIPImageProcessor, feature_extractor: CLIPImageProcessor,
): ):
super().__init__() super().__init__()
...@@ -125,17 +126,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): ...@@ -125,17 +126,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
): ):
latents = latents.detach().requires_grad_() latents = latents.detach().requires_grad_()
if isinstance(self.scheduler, LMSDiscreteScheduler): latent_model_input = self.scheduler.scale_model_input(latents, timestep)
sigma = self.scheduler.sigmas[index]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
else:
latent_model_input = latents
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
alpha_prod_t = self.scheduler.alphas_cumprod[timestep] alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
# compute predicted original sample from predicted noise also called # compute predicted original sample from predicted noise also called
......
...@@ -13,6 +13,7 @@ from diffusers import ( ...@@ -13,6 +13,7 @@ from diffusers import (
AutoencoderKL, AutoencoderKL,
DDIMScheduler, DDIMScheduler,
DiffusionPipeline, DiffusionPipeline,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler, LMSDiscreteScheduler,
PNDMScheduler, PNDMScheduler,
UNet2DConditionModel, UNet2DConditionModel,
...@@ -140,7 +141,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): ...@@ -140,7 +141,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
clip_model: CLIPModel, clip_model: CLIPModel,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
): ):
super().__init__() super().__init__()
...@@ -263,17 +264,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): ...@@ -263,17 +264,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
): ):
latents = latents.detach().requires_grad_() latents = latents.detach().requires_grad_()
if isinstance(self.scheduler, LMSDiscreteScheduler): latent_model_input = self.scheduler.scale_model_input(latents, timestep)
sigma = self.scheduler.sigmas[index]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
else:
latent_model_input = latents
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
alpha_prod_t = self.scheduler.alphas_cumprod[timestep] alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
# compute predicted original sample from predicted noise also called # compute predicted original sample from predicted noise also called
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment