Unverified commit 9b704f76, authored by Patrick von Platen and committed by GitHub

[Img2Img2] Re-add K LMS scheduler (#340)

parent e49dd03d
@@ -5,12 +5,11 @@ import numpy as np
 import torch
 import PIL
-from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from .safety_checker import StableDiffusionSafetyChecker
@@ -31,7 +30,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
     ):
@@ -93,12 +92,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # get the original timestep using init_timestep
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            timesteps = torch.tensor(
+                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+            )
+        else:
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)

         # get prompt text embeddings
         text_input = self.tokenizer(
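Note on the hunk above: LMSDiscreteScheduler is indexed by position in its sigma schedule, while DDIM/PNDM expose actual diffusion timesteps in `scheduler.timesteps`, which is why the starting "timestep" is computed differently per branch. A rough, self-contained sketch of that bookkeeping; all concrete values below are made up for illustration, only the formulas come from the diff:

import torch

num_inference_steps, strength, offset, batch_size = 50, 0.75, 1, 2  # assumed example values

# same formula as the pipeline: how many denoising steps to actually run
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # -> 38

# K-LMS: the "timestep" handed to add_noise is a position counted from the
# start of the sigma schedule
lms_timesteps = torch.tensor([num_inference_steps - init_timestep] * batch_size, dtype=torch.long)  # [12, 12]

# DDIM/PNDM: the pipeline instead reads the real diffusion timestep located
# `init_timestep` entries from the end of scheduler.timesteps
print(init_timestep, lms_timesteps)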
@@ -137,11 +141,22 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             extra_step_kwargs["eta"] = eta

         latents = init_latents
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
+            t_index = t_start + i
+
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

+            # if we use LMSDiscreteScheduler, make sure the latents are scaled by the sigmas
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[t_index]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+                latent_model_input = latent_model_input.to(self.unet.dtype)
+                t = t.to(self.unet.dtype)
+
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
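The division by (sigma**2 + 1) ** 0.5 above exists because, in the K-LMS formulation, a latent noised at level sigma has variance roughly 1 + sigma**2, while the UNet expects inputs of roughly unit variance. A small numeric illustration of that scaling; the sigma value is an assumed example, not taken from the scheduler:

import torch

torch.manual_seed(0)
sigma = 14.6                                     # assumed, large starting-sigma-like value
clean = torch.randn(100_000)                     # ~unit-variance latent
noised = clean + sigma * torch.randn(100_000)    # variance ~ 1 + sigma**2
scaled = noised / ((sigma**2 + 1) ** 0.5)        # the rescaling done in the pipeline

print(float(noised.std()))  # ~14.6
print(float(scaled.std()))  # ~1.0, roughly what the UNet was trained on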
@@ -151,11 +166,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents.to(self.vae.dtype))

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
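The if/else around `scheduler.step` reflects the same index-versus-timestep split as the earlier hunk: with the diffusers API at the time of this commit, K-LMS steps on a position in its sigma schedule (`t_index`) while DDIM/PNDM step on the raw timestep value `t`. A minimal sketch of that dispatch; the helper name is hypothetical and not part of the diff:

from diffusers.schedulers import LMSDiscreteScheduler

def scheduler_step_arg(scheduler, t, t_index):
    """Hypothetical helper: pick the second argument for scheduler.step().
    Assumes the API of this commit, where K-LMS expects a schedule index and
    the other schedulers expect the actual timestep value."""
    return t_index if isinstance(scheduler, LMSDiscreteScheduler) else t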
@@ -445,6 +445,49 @@ class PipelineFastTests(unittest.TestCase):
         expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_img2img_k_lms(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image.to(device)
+
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        )
+        image = output["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_inpaint(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
@@ -892,7 +935,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_stable_diffusion_img2img_pipeline(self):
         ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
-        init_image = ds[1]["image"].resize((768, 512))
+        init_image = ds[2]["image"].resize((768, 512))
         output_image = ds[0]["image"].resize((768, 512))

         model_id = "CompVis/stable-diffusion-v1-4"
@@ -915,12 +958,40 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
-    def test_stable_diffusion_in_paint_pipeline(self):
+    def test_stable_diffusion_img2img_pipeline_k_lms(self):
         ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
         init_image = ds[2]["image"].resize((768, 512))
-        mask_image = ds[3]["image"].resize((768, 512))
-        output_image = ds[4]["image"].resize((768, 512))
+        output_image = ds[1]["image"].resize((768, 512))
+
+        lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, scheduler=lms, use_auth_token=True)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A fantasy landscape, trending on artstation"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[
+            "sample"
+        ][0]
+
+        expected_array = np.array(output_image)
+        sampled_array = np.array(image)
+
+        assert sampled_array.shape == (512, 768, 3)
+        assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_in_paint_pipeline(self):
+        ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
+        init_image = ds[3]["image"].resize((768, 512))
+        mask_image = ds[4]["image"].resize((768, 512))
+        output_image = ds[5]["image"].resize((768, 512))

         model_id = "CompVis/stable-diffusion-v1-4"
         pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, use_auth_token=True)
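For readers who just want to use the re-added scheduler, the slow test above doubles as a usage recipe. A condensed sketch under the API of this commit; the input image path, output path, and CUDA device are placeholders, while the scheduler configuration, model id, and pipeline call mirror the test:

import torch
from PIL import Image
from diffusers import LMSDiscreteScheduler, StableDiffusionImg2ImgPipeline

# same K-LMS configuration as in the tests of this commit
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=lms, use_auth_token=True
)
pipe.to("cuda")

init_image = Image.open("input.png").convert("RGB").resize((768, 512))  # placeholder input image
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    prompt="A fantasy landscape, trending on artstation",
    init_image=init_image,
    strength=0.75,
    guidance_scale=7.5,
    generator=generator,
)["sample"][0]
image.save("fantasy_landscape.png")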