Unverified Commit 0df4ad54 authored by Takuma Mori, committed by GitHub

Add support for `Karras sigmas` in StableDiffusionKDiffusionPipeline (#2874)

* add use_karras_sigmas option

thanks @Stax124

* fix sigma_min/max from scheduler.sigmas

* add docstring

* revert to use k_diffusion_model.sigma, to(device)

* add integration test

* make style
parent 51d970d6
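
For context, a minimal usage sketch of the new option, mirroring the integration test added below (assumes a CUDA device and the `k-diffusion` package installed; model id and prompt are taken from the test):

```python
import torch
from diffusers import StableDiffusionKDiffusionPipeline

# Load the k-diffusion community pipeline and select a k-diffusion sampler by name.
pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe = pipe.to("cuda")
pipe.set_scheduler("sample_dpmpp_2m")  # corresponds to "DPM++2M" in stable-diffusion-webui

# With use_karras_sigmas=True the sampler runs on the Karras noise schedule,
# which corresponds to "DPM++2M Karras" in stable-diffusion-webui.
image = pipe(
    "A painting of a squirrel eating a burger",
    num_inference_steps=15,
    guidance_scale=7.5,
    generator=torch.manual_seed(0),
    use_karras_sigmas=True,
).images[0]
image.save("squirrel.png")
```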
@@ -17,6 +17,7 @@ from typing import Callable, List, Optional, Union

 import torch
 from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
+from k_diffusion.sampling import get_sigmas_karras

 from ...loaders import TextualInversionLoaderMixin
 from ...pipelines import DiffusionPipeline
@@ -409,6 +410,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        use_karras_sigmas: Optional[bool] = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -465,7 +467,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
+            use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+                Use Karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` is equivalent to
+                `DPM++2M` in stable-diffusion-webui; setting this option to `True` additionally makes it `DPM++2M Karras`.

         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
@@ -503,10 +508,18 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):

         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
-        sigmas = self.scheduler.sigmas
+
+        # 5. Prepare sigmas
+        if use_karras_sigmas:
+            sigma_min: float = self.k_diffusion_model.sigmas[0].item()
+            sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
+            sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
+            sigmas = sigmas.to(device)
+        else:
+            sigmas = self.scheduler.sigmas
         sigmas = sigmas.to(prompt_embeds.dtype)

-        # 5. Prepare latent variables
+        # 6. Prepare latent variables
         num_channels_latents = self.unet.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
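
For reference (not part of the diff), `get_sigmas_karras` spaces the `num_inference_steps` noise levels according to the rho-schedule from Karras et al. (2022). A self-contained sketch of that computation, assuming k-diffusion's default `rho=7.0` and illustrative sigma bounds, looks roughly like this:

```python
import torch


def karras_sigmas(n: int, sigma_min: float, sigma_max: float, rho: float = 7.0) -> torch.Tensor:
    """Sketch of the Karras (2022) sigma schedule: interpolate in sigma**(1/rho) space."""
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    # k-diffusion appends a final zero so the sampler ends on a clean sample.
    return torch.cat([sigmas, sigmas.new_zeros(1)])


# Sigmas fall from sigma_max to sigma_min and are packed more densely near sigma_min.
print(karras_sigmas(n=15, sigma_min=0.03, sigma_max=14.6))  # illustrative sigma range
```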
@@ -522,7 +535,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
         self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

-        # 6. Define model function
+        # 7. Define model function
         def model_fn(x, t):
             latent_model_input = torch.cat([x] * 2)
             t = torch.cat([t] * 2)
@@ -533,16 +546,16 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
             return noise_pred

-        # 7. Run k-diffusion solver
+        # 8. Run k-diffusion solver
         latents = self.sampler(model_fn, latents, sigmas)

-        # 8. Post-processing
+        # 9. Post-processing
         image = self.decode_latents(latents)

-        # 9. Run safety checker
+        # 10. Run safety checker
         image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

-        # 10. Convert to PIL
+        # 11. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)

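
For reference, `set_scheduler("sample_dpmpp_2m")` resolves the name to a function in `k_diffusion.sampling`, so the solver call above corresponds roughly to the following standalone sketch. The identity denoiser is only a placeholder for `model_fn`, and the sigma bounds are illustrative:

```python
import torch
from k_diffusion.sampling import get_sigmas_karras, sample_dpmpp_2m


def toy_denoiser(x, sigma):
    # Stand-in for model_fn: a real denoiser returns the model's denoised estimate of x.
    return x


sigmas = get_sigmas_karras(n=15, sigma_min=0.03, sigma_max=14.6)  # illustrative sigma range
x = torch.randn(1, 4, 64, 64) * sigmas[0]  # start from noise at sigma_max, latent-shaped
sample = sample_dpmpp_2m(toy_denoiser, x, sigmas)
print(sample.shape)
```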
@@ -75,3 +75,32 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1
+
+    def test_stable_diffusion_karras_sigmas(self):
+        sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        sd_pipe.set_scheduler("sample_dpmpp_2m")
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=7.5,
+            num_inference_steps=15,
+            output_type="np",
+            use_karras_sigmas=True,
+        )
+
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array(
+            [0.11381689, 0.12112921, 0.1389457, 0.12549606, 0.1244964, 0.10831517, 0.11562866, 0.10867816, 0.10499048]
+        )
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2