Commit 13e37cab authored by Patrick von Platen's avatar Patrick von Platen
Browse files

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents 760dcb1f 889aa600
...@@ -5,18 +5,17 @@ ...@@ -5,18 +5,17 @@
The command to train a DDPM UNet model on the Oxford Flowers dataset: The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash ```bash
python -m torch.distributed.launch \ accelerate launch train_unconditional.py \
--nproc_per_node 4 \
train_unconditional.py \
--dataset="huggan/flowers-102-categories" \ --dataset="huggan/flowers-102-categories" \
--resolution=64 \ --resolution=64 \
--output_dir="flowers-ddpm" \ --output_dir="ddpm-ema-flowers-64" \
--batch_size=16 \ --train_batch_size=16 \
--num_epochs=100 \ --num_epochs=100 \
--gradient_accumulation_steps=1 \ --gradient_accumulation_steps=1 \
--lr=1e-4 \ --learning_rate=1e-4 \
--warmup_steps=500 \ --lr_warmup_steps=500 \
--mixed_precision=no --mixed_precision=no \
--push_to_hub
``` ```
A full training run takes 2 hours on 4xV100 GPUs. A full training run takes 2 hours on 4xV100 GPUs.
...@@ -29,18 +28,17 @@ A full training run takes 2 hours on 4xV100 GPUs. ...@@ -29,18 +28,17 @@ A full training run takes 2 hours on 4xV100 GPUs.
The command to train a DDPM UNet model on the Pokemon dataset: The command to train a DDPM UNet model on the Pokemon dataset:
```bash ```bash
python -m torch.distributed.launch \ accelerate launch train_unconditional.py \
--nproc_per_node 4 \
train_unconditional.py \
--dataset="huggan/pokemon" \ --dataset="huggan/pokemon" \
--resolution=64 \ --resolution=64 \
--output_dir="pokemon-ddpm" \ --output_dir="ddpm-ema-pokemon-64" \
--batch_size=16 \ --train_batch_size=16 \
--num_epochs=100 \ --num_epochs=100 \
--gradient_accumulation_steps=1 \ --gradient_accumulation_steps=1 \
--lr=1e-4 \ --learning_rate=1e-4 \
--warmup_steps=500 \ --lr_warmup_steps=500 \
--mixed_precision=no --mixed_precision=no \
--push_to_hub
``` ```
A full training run takes 2 hours on 4xV100 GPUs. A full training run takes 2 hours on 4xV100 GPUs.
......
...@@ -4,10 +4,10 @@ import os ...@@ -4,10 +4,10 @@ import os
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from accelerate import Accelerator, DistributedDataParallelKwargs from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import load_dataset from datasets import load_dataset
from diffusers import DDIMPipeline, DDIMScheduler, UNetModel from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel
from diffusers.hub_utils import init_git_repo, push_to_hub from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
...@@ -27,25 +27,37 @@ logger = get_logger(__name__) ...@@ -27,25 +27,37 @@ logger = get_logger(__name__)
def main(args): def main(args):
ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
logging_dir = os.path.join(args.output_dir, args.logging_dir) logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator( accelerator = Accelerator(
mixed_precision=args.mixed_precision, mixed_precision=args.mixed_precision,
log_with="tensorboard", log_with="tensorboard",
logging_dir=logging_dir, logging_dir=logging_dir,
kwargs_handlers=[ddp_unused_params],
) )
model = UNetModel( model = UNetUnconditionalModel(
attn_resolutions=(16,), image_size=args.resolution,
ch=128, in_channels=3,
ch_mult=(1, 2, 4, 8), out_channels=3,
dropout=0.0,
num_res_blocks=2, num_res_blocks=2,
resamp_with_conv=True, block_channels=(128, 128, 256, 256, 512, 512),
resolution=args.resolution, down_blocks=(
"UNetResDownBlock2D",
"UNetResDownBlock2D",
"UNetResDownBlock2D",
"UNetResDownBlock2D",
"UNetResAttnDownBlock2D",
"UNetResDownBlock2D",
),
up_blocks=(
"UNetResUpBlock2D",
"UNetResAttnUpBlock2D",
"UNetResUpBlock2D",
"UNetResUpBlock2D",
"UNetResUpBlock2D",
"UNetResUpBlock2D",
),
) )
noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt") noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
model.parameters(), model.parameters(),
lr=args.learning_rate, lr=args.learning_rate,
...@@ -92,19 +104,6 @@ def main(args): ...@@ -92,19 +104,6 @@ def main(args):
run = os.path.split(__file__)[-1].split(".")[0] run = os.path.split(__file__)[-1].split(".")[0]
accelerator.init_trackers(run) accelerator.init_trackers(run)
# Train!
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size() if is_distributed else 1
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataloader.dataset)}")
logger.info(f" Num Epochs = {args.num_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
global_step = 0 global_step = 0
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
model.train() model.train()
...@@ -112,45 +111,37 @@ def main(args): ...@@ -112,45 +111,37 @@ def main(args):
progress_bar.set_description(f"Epoch {epoch}") progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
clean_images = batch["input"] clean_images = batch["input"]
noise_samples = torch.randn(clean_images.shape).to(clean_images.device) # Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
bsz = clean_images.shape[0] bsz = clean_images.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() # Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
).long()
# add noise onto the clean images according to the noise magnitude at each timestep # Add noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps) noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
if step % args.gradient_accumulation_steps != 0: with accelerator.accumulate(model):
with accelerator.no_sync(model): # Predict the noise residual
output = model(noisy_images, timesteps) noise_pred = model(noisy_images, timesteps)["sample"]
# predict the noise residual loss = F.mse_loss(noise_pred, noise)
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
else:
output = model(noisy_images, timesteps)
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss) accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
if args.use_ema:
ema_model.step(model) ema_model.step(model)
optimizer.zero_grad() optimizer.zero_grad()
progress_bar.update(1) progress_bar.update(1)
progress_bar.set_postfix( logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay if args.use_ema:
) logs["ema_decay"] = ema_model.decay
accelerator.log( progress_bar.set_postfix(**logs)
{ accelerator.log(logs, step=global_step)
"train_loss": loss.detach().item(),
"epoch": epoch,
"ema_decay": ema_model.decay,
"step": global_step,
},
step=global_step,
)
global_step += 1 global_step += 1
progress_bar.close() progress_bar.close()
...@@ -159,14 +150,14 @@ def main(args): ...@@ -159,14 +150,14 @@ def main(args):
# Generate a sample image for visual inspection # Generate a sample image for visual inspection
if accelerator.is_main_process: if accelerator.is_main_process:
with torch.no_grad(): with torch.no_grad():
pipeline = DDIMPipeline( pipeline = DDPMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model), unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
noise_scheduler=noise_scheduler, scheduler=noise_scheduler,
) )
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise) # run pipeline in inference (sample random noise and denoise)
images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50) images = pipeline(generator=generator, batch_size=args.eval_batch_size)
# denormalize the images and save to tensorboard # denormalize the images and save to tensorboard
images_processed = (images.cpu() + 1.0) * 127.5 images_processed = (images.cpu() + 1.0) * 127.5
...@@ -174,6 +165,7 @@ def main(args): ...@@ -174,6 +165,7 @@ def main(args):
accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch) accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model # save the model
if args.push_to_hub: if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
...@@ -188,12 +180,13 @@ if __name__ == "__main__": ...@@ -188,12 +180,13 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1) parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories") parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
parser.add_argument("--output_dir", type=str, default="ddpm-model") parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
parser.add_argument("--overwrite_output_dir", action="store_true") parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--resolution", type=int, default=64) parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--train_batch_size", type=int, default=16) parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--eval_batch_size", type=int, default=16) parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=100) parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--save_model_epochs", type=int, default=5)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--lr_scheduler", type=str, default="cosine") parser.add_argument("--lr_scheduler", type=str, default="cosine")
...@@ -202,6 +195,7 @@ if __name__ == "__main__": ...@@ -202,6 +195,7 @@ if __name__ == "__main__":
parser.add_argument("--adam_beta2", type=float, default=0.999) parser.add_argument("--adam_beta2", type=float, default=0.999)
parser.add_argument("--adam_weight_decay", type=float, default=1e-6) parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
parser.add_argument("--adam_epsilon", type=float, default=1e-3) parser.add_argument("--adam_epsilon", type=float, default=1e-3)
parser.add_argument("--use_ema", action="store_true", default=True)
parser.add_argument("--ema_inv_gamma", type=float, default=1.0) parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
parser.add_argument("--ema_power", type=float, default=3 / 4) parser.add_argument("--ema_power", type=float, default=3 / 4)
parser.add_argument("--ema_max_decay", type=float, default=0.9999) parser.add_argument("--ema_max_decay", type=float, default=0.9999)
......
...@@ -145,7 +145,7 @@ class Decoder(nn.Module): ...@@ -145,7 +145,7 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1] block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1) curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res) self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in # z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
......
...@@ -19,6 +19,7 @@ import os ...@@ -19,6 +19,7 @@ import os
from typing import Optional, Union from typing import Optional, Union
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from PIL import Image
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging from .utils import DIFFUSERS_CACHE, logging
...@@ -120,6 +121,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -120,6 +121,7 @@ class DiffusionPipeline(ConfigMixin):
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
...@@ -131,6 +133,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -131,6 +133,7 @@ class DiffusionPipeline(ConfigMixin):
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
revision=revision,
) )
else: else:
cached_folder = pretrained_model_name_or_path cached_folder = pretrained_model_name_or_path
...@@ -187,3 +190,15 @@ class DiffusionPipeline(ConfigMixin): ...@@ -187,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
# 5. Instantiate the pipeline # 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs) model = pipeline_class(**init_kwargs)
return model return model
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
...@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50): def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
# eta corresponds to η in paper and should be between [0, 1] # eta corresponds to η in paper and should be between [0, 1]
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
...@@ -53,4 +53,9 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -53,4 +53,9 @@ class DDIMPipeline(DiffusionPipeline):
# do x_t -> x_t-1 # do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"] image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image} return {"sample": image}
...@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None): def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
...@@ -54,4 +54,9 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -54,4 +54,9 @@ class DDPMPipeline(DiffusionPipeline):
# 3. set current image to prev_image: x_t -> x_t-1 # 3. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image image = pred_prev_image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image} return {"sample": image}
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
import tqdm from tqdm.auto import tqdm
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_outputs import BaseModelOutput
...@@ -30,51 +30,67 @@ class LatentDiffusionPipeline(DiffusionPipeline): ...@@ -30,51 +30,67 @@ class LatentDiffusionPipeline(DiffusionPipeline):
eta=0.0, eta=0.0,
guidance_scale=1.0, guidance_scale=1.0,
num_inference_steps=50, num_inference_steps=50,
output_type="pil",
): ):
# eta corresponds to η in paper and should be between [0, 1] # eta corresponds to η in paper and should be between [0, 1]
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = len(prompt)
self.unet.to(torch_device) self.unet.to(torch_device)
self.vqvae.to(torch_device) self.vqvae.to(torch_device)
self.bert.to(torch_device) self.bert.to(torch_device)
# get unconditional embeddings for classifier free guidence # get unconditional embeddings for classifier free guidance
if guidance_scale != 1.0: if guidance_scale != 1.0:
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to( uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
torch_device uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))
)
uncond_embeddings = self.bert(uncond_input.input_ids)
# get text embedding # get prompt text embeddings
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
text_embedding = self.bert(text_input.input_ids) text_embeddings = self.bert(text_input.input_ids.to(torch_device))
image = torch.randn( latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator, generator=generator,
).to(torch_device) )
latents = latents.to(torch_device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm.tqdm(self.scheduler.timesteps): for t in tqdm(self.scheduler.timesteps):
# 1. predict noise residual if guidance_scale == 1.0:
pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding) # guidance_scale of 1 means no guidance
latents_input = latents
context = text_embeddings
else:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = torch.cat([latents] * 2)
context = torch.cat([uncond_embeddings, text_embeddings])
# predict the noise residual
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
# perform guidance
if guidance_scale != 1.0:
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
if isinstance(pred_noise_t, dict): # compute the previous noisy sample x_t -> x_t-1
pred_noise_t = pred_noise_t["sample"] latents = self.scheduler.step(noise_pred, t, latents, eta)["prev_sample"]
# 2. predict previous mean of image x_t-1 and add variance depending on eta # scale and decode the image latents with vae
# do x_t -> x_t-1 latents = 1 / 0.18215 * latents
image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"] image = self.vqvae.decode(latents)
# scale and decode image with vae image = (image / 2 + 0.5).clamp(0, 1)
image = 1 / 0.18215 * image image = image.cpu().permute(0, 2, 3, 1).numpy()
image = self.vqvae.decode(image) if output_type == "pil":
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) image = self.numpy_to_pil(image)
return image return {"sample": image}
################################################################################ ################################################################################
......
...@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
num_inference_steps=50,
): ):
# eta corresponds to η in paper and should be between [0, 1] # eta corresponds to η in paper and should be between [0, 1]
...@@ -28,25 +23,26 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -28,25 +23,26 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self.unet.to(torch_device) self.unet.to(torch_device)
self.vqvae.to(torch_device) self.vqvae.to(torch_device)
image = torch.randn( latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator, generator=generator,
).to(torch_device) )
latents = latents.to(torch_device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm(self.scheduler.timesteps): for t in tqdm(self.scheduler.timesteps):
with torch.no_grad(): # predict the noise residual
model_output = self.unet(image, t) noise_prediction = self.unet(latents, t)["sample"]
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
if isinstance(model_output, dict): # decode the image latents with the VAE
model_output = model_output["sample"] image = self.vqvae.decode(latents)
# 2. predict previous mean of image x_t-1 and add variance depending on eta image = (image / 2 + 0.5).clamp(0, 1)
# do x_t -> x_t-1 image = image.cpu().permute(0, 2, 3, 1).numpy()
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"] if output_type == "pil":
image = self.numpy_to_pil(image)
# decode image with vae
with torch.no_grad():
image = self.vqvae.decode(image)
return {"sample": image} return {"sample": image}
...@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50): def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
# For more information on the sampling method you can take a look at Algorithm 2 of # For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf # the official paper: https://arxiv.org/pdf/2202.09778.pdf
if torch_device is None: if torch_device is None:
...@@ -43,18 +43,20 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -43,18 +43,20 @@ class PNDMPipeline(DiffusionPipeline):
) )
image = image.to(torch_device) image = image.to(torch_device)
prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm(range(len(prk_time_steps))): for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
t_orig = prk_time_steps[t] model_output = self.unet(image, t)["sample"]
model_output = self.unet(image, t_orig)["sample"]
image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"] image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
timesteps = self.scheduler.get_time_steps(num_inference_steps) for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
for t in tqdm(range(len(timesteps))): model_output = self.unet(image, t)["sample"]
t_orig = timesteps[t]
model_output = self.unet(image, t_orig)["sample"]
image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"] image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image} return {"sample": image}
...@@ -2,15 +2,16 @@ ...@@ -2,15 +2,16 @@
import torch import torch
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from tqdm.auto import tqdm
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVePipeline(DiffusionPipeline): class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, model, scheduler): def __init__(self, model, scheduler):
super().__init__() super().__init__()
self.register_modules(model=model, scheduler=scheduler) self.register_modules(model=model, scheduler=scheduler)
def __call__(self, num_inference_steps=2000, generator=None): @torch.no_grad()
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size img_size = self.model.config.image_size
...@@ -24,11 +25,10 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -24,11 +25,10 @@ class ScoreSdeVePipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.set_sigmas(num_inference_steps) self.scheduler.set_sigmas(num_inference_steps)
for i, t in enumerate(self.scheduler.timesteps): for i, t in tqdm(enumerate(self.scheduler.timesteps)):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
for _ in range(self.scheduler.correct_steps): for _ in range(self.scheduler.correct_steps):
with torch.no_grad():
model_output = self.model(sample, sigma_t) model_output = self.model(sample, sigma_t)
if isinstance(model_output, dict): if isinstance(model_output, dict):
...@@ -45,4 +45,9 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -45,4 +45,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
output = self.scheduler.step_pred(model_output, t, sample) output = self.scheduler.step_pred(model_output, t, sample)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
return sample_mean sample = sample.clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)
return {"sample": sample}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
import pdb
from typing import Union from typing import Union
import numpy as np import numpy as np
...@@ -71,8 +72,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -71,8 +72,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.one = np.array(1.0) self.one = np.array(1.0)
self.set_format(tensor_format=tensor_format)
# For now we only support F-PNDM, i.e. the runge-kutta method # For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2. # mainly at formula (9), (12), (13) and the Algorithm 2.
...@@ -82,49 +81,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -82,49 +81,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.cur_model_output = 0 self.cur_model_output = 0
self.cur_sample = None self.cur_sample = None
self.ets = [] self.ets = []
self.prk_time_steps = {}
self.time_steps = {}
self.set_prk_mode()
def get_prk_time_steps(self, num_inference_steps): # setable values
if num_inference_steps in self.prk_time_steps: self.num_inference_steps = None
return self.prk_time_steps[num_inference_steps] self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.prk_timesteps = None
self.plms_timesteps = None
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
inference_step_times = list( def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
self.timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
) )
prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile( prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
) )
self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1])) self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self.timesteps[:-3]))
return self.prk_time_steps[num_inference_steps]
def get_time_steps(self, num_inference_steps):
if num_inference_steps in self.time_steps:
return self.time_steps[num_inference_steps]
inference_step_times = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
return self.time_steps[num_inference_steps]
def set_prk_mode(self):
self.mode = "prk"
def set_plms_mode(self):
self.mode = "plms"
def step(self, *args, **kwargs):
if self.mode == "prk":
return self.step_prk(*args, **kwargs)
if self.mode == "plms":
return self.step_plms(*args, **kwargs)
raise ValueError(f"mode {self.mode} does not exist.") self.set_format(tensor_format=self.tensor_format)
def step_prk( def step_prk(
self, self,
...@@ -138,7 +117,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -138,7 +117,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
solution to the differential equation. solution to the differential equation.
""" """
t = timestep t = timestep
prk_time_steps = self.get_prk_time_steps(num_inference_steps) prk_time_steps = self.prk_timesteps
t_orig = prk_time_steps[t // 4 * 4] t_orig = prk_time_steps[t // 4 * 4]
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)] t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
...@@ -180,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -180,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information." "for more information."
) )
timesteps = self.get_time_steps(num_inference_steps) timesteps = self.plms_timesteps
t_orig = timesteps[t] t_orig = timesteps[t]
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)] t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
......
...@@ -18,11 +18,11 @@ import inspect ...@@ -18,11 +18,11 @@ import inspect
import math import math
import tempfile import tempfile
import unittest import unittest
from atexit import register
import numpy as np import numpy as np
import torch import torch
import PIL
from diffusers import UNetConditionalModel # noqa: F401 TODO(Patrick) - need to write tests with it from diffusers import UNetConditionalModel # noqa: F401 TODO(Patrick) - need to write tests with it
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
...@@ -704,11 +704,11 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -704,11 +704,11 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"] image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0) generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator)["sample"] new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow @slow
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
...@@ -722,11 +722,32 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -722,11 +722,32 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"] image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0) generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator)["sample"] new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"
pipe = DDIMPipeline.from_pretrained(model_path)
generator = torch.manual_seed(0)
images = pipe(generator=generator, output_type="numpy")["sample"]
assert images.shape == (1, 32, 32, 3)
assert isinstance(images, np.ndarray)
images = pipe(generator=generator, output_type="pil")["sample"]
assert isinstance(images, list)
assert len(images) == 1
assert isinstance(images[0], PIL.Image.Image)
# use PIL by default
images = pipe(generator=generator)["sample"]
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)
@slow @slow
def test_ddpm_cifar10(self): def test_ddpm_cifar10(self):
...@@ -739,15 +760,13 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -739,15 +760,13 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"] image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 32, 32) assert image.shape == (1, 32, 32, 3)
expected_slice = torch.tensor( expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
[-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231] assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow @slow
def test_ddim_lsun(self): def test_ddim_lsun(self):
...@@ -759,15 +778,13 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -759,15 +778,13 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"] image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 256, 256) assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor( expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
[-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863] assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow @slow
def test_ddim_cifar10(self): def test_ddim_cifar10(self):
...@@ -779,15 +796,13 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -779,15 +796,13 @@ class PipelineTesterMixin(unittest.TestCase):
ddim = DDIMPipeline(unet=unet, scheduler=scheduler) ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0)["sample"] image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 32, 32) assert image.shape == (1, 32, 32, 3)
expected_slice = torch.tensor( expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
[-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094] assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow @slow
def test_pndm_cifar10(self): def test_pndm_cifar10(self):
...@@ -798,15 +813,13 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -798,15 +813,13 @@ class PipelineTesterMixin(unittest.TestCase):
pndm = PNDMPipeline(unet=unet, scheduler=scheduler) pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = pndm(generator=generator)["sample"] image = pndm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 32, 32) assert image.shape == (1, 32, 32, 3)
expected_slice = torch.tensor( expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
[-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500] assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow @slow
def test_ldm_text2img(self): def test_ldm_text2img(self):
...@@ -814,13 +827,15 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -814,13 +827,15 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=20) image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
"sample"
]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 256, 256) assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow @slow
def test_ldm_text2img_fast(self): def test_ldm_text2img_fast(self):
...@@ -828,55 +843,44 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -828,55 +843,44 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=1) image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 256, 256) assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow @slow
def test_score_sde_ve_pipeline(self): def test_score_sde_ve_pipeline(self):
model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024") model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")
torch.manual_seed(0) torch.manual_seed(0)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024") scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
torch.manual_seed(0) torch.manual_seed(0)
image = sde_ve(num_inference_steps=2) image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
if model.device.type == "cpu":
# patrick's cpu
expected_image_sum = 3384805888.0
expected_image_mean = 1076.00085
# m1 mbp image_slice = image[0, -3:, -3:, -1]
# expected_image_sum = 3384805376.0
# expected_image_mean = 1076.000610351562
else:
expected_image_sum = 3382849024.0
expected_image_mean = 1075.3788
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 assert image.shape == (1, 256, 256, 3)
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow @slow
def test_ldm_uncond(self): def test_ldm_uncond(self):
ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256") ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5)["sample"] image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 256, 256) assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor( expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
[-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106] assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
...@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps = kwargs.pop("num_inference_steps", None) num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
...@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase):
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
...@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
config.update(**kwargs) config.update(**kwargs)
return config return config
def check_over_configs_pmls(self, time_step=0, **config): def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config) scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"])
# copy over dummy past residuals # copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:] scheduler.ets = dummy_past_residuals[:]
scheduler.set_plms_mode()
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode()
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward_pmls(self, time_step=0, **forward_kwargs): output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_from_pretrained_save_pretrained(self):
pass
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs) kwargs.update(forward_kwargs)
sample = self.dummy_sample sample = self.dummy_sample
...@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest):
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"])
# copy over dummy past residuals # copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:] scheduler.ets = dummy_past_residuals[:]
scheduler.set_plms_mode()
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode() new_scheduler.set_timesteps(kwargs["num_inference_steps"])
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
sample_pt = torch.tensor(sample)
residual_pt = 0.1 * sample_pt
dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
# copy over dummy past residuals
scheduler_pt.ets = dummy_past_residuals_pt[:]
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
# copy over dummy past residuals
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:]
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_timesteps(self): def test_timesteps(self):
for timesteps in [100, 1000]: for timesteps in [100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps) self.check_over_configs(num_train_timesteps=timesteps)
def test_timesteps_pmls(self):
for timesteps in [100, 1000]:
self.check_over_configs_pmls(num_train_timesteps=timesteps)
def test_betas(self): def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]): for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end) self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_betas_pmls(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self): def test_schedules(self):
for schedule in ["linear", "squaredcos_cap_v2"]: for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule) self.check_over_configs(beta_schedule=schedule)
def test_schedules_pmls(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_time_indices(self): def test_time_indices(self):
for t in [1, 5, 10]: for t in [1, 5, 10]:
self.check_over_forward(time_step=t) self.check_over_forward(time_step=t)
def test_time_indices_pmls(self):
for t in [1, 5, 10]:
self.check_over_forward_pmls(time_step=t)
def test_inference_steps(self): def test_inference_steps(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_inference_steps_pmls(self): def test_inference_plms_no_past_residuals(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
def test_inference_pmls_no_past_residuals(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.set_plms_mode() scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
...@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest):
num_inference_steps = 10 num_inference_steps = 10
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps) for i, t in enumerate(scheduler.prk_timesteps):
for t in range(len(prk_time_steps)): residual = model(sample, t)
t_orig = prk_time_steps[t] sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
residual = model(sample, t_orig)
sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
timesteps = scheduler.get_time_steps(num_inference_steps)
for t in range(len(timesteps)):
t_orig = timesteps[t]
residual = model(sample, t_orig)
sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"] for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
result_sum = np.sum(np.abs(sample)) result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(sample)) result_mean = np.mean(np.abs(sample))
...@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
...@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment