"vscode:/vscode.git/clone" did not exist on "4279d8ca44eca22994406345d91a31fb89b142a7"
Unverified Commit 5782e039 authored by Suraj Patil, committed by GitHub

Stable diffusion pipeline (#168)



* add stable diffusion pipeline

* get rid of multiple if/else

* batch_size is unused

* add type hints

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

* fix some bugs
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 92b6dbba
@@ -31,6 +31,7 @@ from .training_utils import EMAModel

 if is_transformers_available():
-    from .pipelines import LDMTextToImagePipeline
+    from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline
 else:
     from .utils.dummy_transformers_objects import *
@@ -9,3 +9,4 @@ from .stochatic_karras_ve import KarrasVePipeline

 if is_transformers_available():
     from .latent_diffusion import LDMTextToImagePipeline
+    from .stable_diffusion import StableDiffusionPipeline
@@ -62,9 +62,10 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwrags = {}
+
+        extra_kwargs = {}
         if accepts_eta:
-            extra_kwrags["eta"] = eta
+            extra_kwargs["eta"] = eta

         for t in tqdm(self.scheduler.timesteps):
             if guidance_scale == 1.0:
@@ -86,7 +87,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

         # compute the previous noisy sample x_t -> x_t-1
-        latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
+        latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
...
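For reference, the guidance line kept as context in this hunk implements classifier-free guidance, eps_hat = eps_uncond + w * (eps_text - eps_uncond). A tiny numeric check of that formula (the tensor values are made up for illustration):

import torch

noise_pred_uncond = torch.tensor([0.10, 0.20])      # made-up unconditional prediction
noise_prediction_text = torch.tensor([0.30, 0.10])  # made-up text-conditioned prediction

for guidance_scale in (1.0, 6.0):
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    # w = 1.0 reproduces the text-conditioned prediction; w > 1.0 extrapolates past it.
    print(guidance_scale, noise_pred.tolist())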
@@ -35,15 +35,16 @@ class LDMPipeline(DiffusionPipeline):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwrags = {}
+
+        extra_kwargs = {}
         if accepts_eta:
-            extra_kwrags["eta"] = eta
+            extra_kwargs["eta"] = eta

         for t in tqdm(self.scheduler.timesteps):
             # predict the noise residual
             noise_prediction = self.unet(latents, t)["sample"]

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]
+            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]

         # decode the image latents with the VAE
         image = self.vqvae.decode(latents)
...
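The extra_kwrags -> extra_kwargs renames above fix a typo in a feature-detection pattern shared by both pipelines: the scheduler's step signature is inspected at runtime so that eta is only forwarded to schedulers that accept it. A minimal standalone sketch of that pattern (both step functions are hypothetical stand-ins, for illustration only):

import inspect

def ddim_like_step(model_output, timestep, sample, eta=0.0):
    """Hypothetical stand-in for a scheduler step that accepts `eta`."""
    return {"prev_sample": sample}

def pndm_like_step(model_output, timestep, sample):
    """Hypothetical stand-in for a scheduler step without `eta`."""
    return {"prev_sample": sample}

for step in (ddim_like_step, pndm_like_step):
    # Same check as in the hunks above: only forward `eta` when the signature has it.
    extra_kwargs = {}
    if "eta" in inspect.signature(step).parameters:
        extra_kwargs["eta"] = 0.0
    out = step(None, 0, sample=None, **extra_kwargs)
    print(step.__name__, "received:", sorted(extra_kwargs))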
src/diffusers/pipelines/stable_diffusion/__init__.py (new file):

from ...utils import is_transformers_available


if is_transformers_available():
    from .pipeline_stable_diffusion import StableDiffusionPipeline
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (new file):

import inspect
from typing import List, Optional, Union

import torch
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler


class StableDiffusionPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler],
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 1.0,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        torch_device: Optional[Union[str, torch.device]] = None,
        output_type: Optional[str] = "pil",
    ):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        self.unet.to(torch_device)
        self.vae.to(torch_device)
        self.text_encoder.to(torch_device)

        # get prompt text embeddings
        text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise
        latents = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        latents = latents.to(torch_device)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        self.scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image}
@@ -10,11 +10,12 @@ from ...schedulers import KarrasVeScheduler

 class KarrasVePipeline(DiffusionPipeline):
     """
-    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
-    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
-    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
-    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+    the VE column of Table 1 from [1] for reference.
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+    differential equations." https://arxiv.org/abs/2011.13456
     """

     unet: UNet2DModel
...
@@ -24,11 +24,12 @@ from .scheduling_utils import SchedulerMixin

 class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
     """
-    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
-    Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
-    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
-    [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+    the VE column of Table 1 from [1] for reference.
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+    differential equations." https://arxiv.org/abs/2011.13456
     """

     @register_to_config
@@ -43,10 +44,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         tensor_format="pt",
     ):
         """
-        For more details on the parameters, see the original paper's Appendix E.:
-        "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364.
-        The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model
-        are described in Table 5 of the paper.
+        For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+        Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+        optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.

         Args:
             sigma_min (`float`): minimum noise magnitude
@@ -81,8 +81,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
     def add_noise_to_input(self, sample, sigma, generator=None):
         """
-        Explicit Langevin-like "churn" step of adding noise to the sample according to
-        a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+        higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
         """
         if self.s_min <= sigma <= self.s_max:
             gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
...
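For context, a self-contained sketch of the churn step this docstring describes, following Algorithm 2 of Karras et al.; the function name is hypothetical, and the noise-injection line at the end is an assumption taken from the paper rather than from the lines shown in this hunk:

import torch

def add_noise_to_input_sketch(sample, sigma, s_churn, s_min, s_max, s_noise, num_inference_steps, generator=None):
    """Sketch of the "churn" step: raise the noise level from sigma to
    sigma_hat = sigma + gamma * sigma, then add matching Gaussian noise."""
    if s_min <= sigma <= s_max:
        gamma = min(s_churn / num_inference_steps, 2**0.5 - 1)  # as in the hunk above
    else:
        gamma = 0  # outside [s_min, s_max] no extra noise is churned in
    sigma_hat = sigma + gamma * sigma
    # Assumption from Algorithm 2 of the paper: the added variance must be
    # sigma_hat^2 - sigma^2 so the sample lands exactly at noise level sigma_hat.
    eps = s_noise * torch.randn(sample.shape, generator=generator)
    return sample + ((sigma_hat**2 - sigma**2) ** 0.5) * eps, sigma_hat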
@@ -8,3 +8,10 @@ class LDMTextToImagePipeline(metaclass=DummyObject):

     def __init__(self, *args, **kwargs):
         requires_backends(self, ["transformers"])
+
+
+class StableDiffusionPipeline(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
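The placeholder added here means `from diffusers import StableDiffusionPipeline` still succeeds when `transformers` is not installed, and fails with a readable error only when the class is actually used. A minimal, self-contained sketch of how such a dummy-object pattern can work (an illustration under those assumptions, not the library's exact `dummy_transformers_objects` code):

import importlib.util

def requires_backends(obj, backends):
    """Sketch: raise a readable ImportError for every backend that is missing."""
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
        raise ImportError(f"{name} requires the missing backends: {missing}")

class DummyObject(type):
    """Metaclass sketch: attribute access on the placeholder class itself
    (e.g. `StableDiffusionPipeline.from_pretrained`) also triggers the check."""
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)

class StableDiffusionPipeline(metaclass=DummyObject):
    _backends = ["transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["transformers"])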
@@ -45,6 +45,8 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel

+from ..src.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
+
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -839,6 +841,38 @@ class PipelineTesterMixin(unittest.TestCase):
         expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    @slow
+    def test_stable_diffusion(self):
+        ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
+            "sample"
+        ]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        # TODO: update the expected_slice
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    @slow
+    def test_stable_diffusion_fast(self):
+        ldm = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        # TODO: update the expected_slice
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     @slow
     def test_score_sde_ve_pipeline(self):
         model_id = "google/ncsnpp-church-256"
...