Commit 542c7868 authored by patil-suraj

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents 147d8e07 da1f920e
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src

-check_dirs := tests src utils
+check_dirs := examples tests src utils

 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
...
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline


class PNDM(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
        # For details on the sampling method, see Algorithm 2 of the PNDM paper:
        # https://arxiv.org/pdf/2202.09778.pdf
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        num_trained_timesteps = self.noise_scheduler.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        self.unet.to(torch_device)

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            generator=generator,
        )
        image = image.to(torch_device)

        seq = list(inference_step_times)
        seq_next = [-1] + list(seq[:-1])
        model = self.unet
        # Warm-up: three pseudo Runge-Kutta steps, four model evaluations each
        warmup_steps = [len(seq) - (i // 4 + 1) for i in range(3 * 4)]
        ets = []
        prev_image = image
        for step, step_idx in enumerate(warmup_steps):
            i = seq[step_idx]
            j = seq_next[step_idx]
            t = torch.ones(image.shape[0]) * i
            t_next = torch.ones(image.shape[0]) * j
            t_list = [t, (t + t_next) / 2, t_next]

            residual = model(image.to(torch_device), t.to(torch_device))
            residual = residual.to("cpu")
            image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_list[0], t_list[1], residual)

            if step % 4 == 0:
                ets.append(residual)
                prev_image = image
        # Main loop: pseudo Runge-Kutta for the first three steps (not enough
        # history yet), then the fourth-order pseudo linear multistep formula
        ets = []
        step_idx = len(seq) - 1
        while step_idx >= 0:
            i = seq[step_idx]
            j = seq_next[step_idx]
            t = torch.ones(image.shape[0]) * i
            t_next = torch.ones(image.shape[0]) * j

            residual = model(image.to(torch_device), t.to(torch_device))
            residual = residual.to("cpu")

            t_list = [t, (t + t_next) / 2, t_next]
            ets.append(residual)

            if len(ets) <= 3:
                image = image.to("cpu")
                x_2 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], residual)
                e_2 = model(x_2.to(torch_device), t_list[1].to(torch_device)).to("cpu")
                x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
                e_3 = model(x_3.to(torch_device), t_list[1].to(torch_device)).to("cpu")
                x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
                e_4 = model(x_4.to(torch_device), t_list[2].to(torch_device)).to("cpu")
                # classic RK4 weighting of the four evaluations
                residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
            else:
                # Adams-Bashforth style combination of the last four residuals
                residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])

            img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
            image = img_next
            step_idx = step_idx - 1
            # if len(prev_noises) in [1, 2]:
            #     t = (t + t_next) / 2
            # elif len(prev_noises) == 3:
            #     t = t_next / 2
            # if len(prev_noises) == 0:
            #     ets.append(residual)
            #
            # if len(ets) > 3:
            #     residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
            #     step_idx = step_idx - 1
            # elif len(ets) <= 3 and len(prev_noises) == 3:
            #     residual = (1 / 6) * (prev_noises[-3] + 2 * prev_noises[-2] + 2 * prev_noises[-1] + residual)
            #     prev_noises = []
            #     step_idx = step_idx - 1
            # elif len(ets) <= 3 and len(prev_noises) < 3:
            #     prev_noises.append(residual)
            #     if len(prev_noises) < 2:
            #         t_next = (t + t_next) / 2
            #
            # image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)

        return image
        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand the notation.
        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointing to x_t"
        # - pred_prev_image -> "x_t-1"
        # for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
        #     1. predict noise residual
        #     with torch.no_grad():
        #         residual = self.unet(image, inference_step_times[t])
        #
        #     2. predict previous mean of image x_t-1
        #     pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
        #
        #     3. optionally sample variance
        #     variance = 0
        #     if eta > 0:
        #         noise = torch.randn(image.shape, generator=generator).to(image.device)
        #         variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
        #
        #     4. set current image to prev_image: x_t -> x_t-1
        #     image = pred_prev_image + variance
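Since the commented notes above cite them, here are formulas (12) and (16) of the DDIM paper transcribed for reference (a transcription, not part of the committed file; alpha denotes the cumulative product as in the paper):

% DDIM eq. (12): one reverse step from x_t to x_{t-1}
x_{t-1} = \sqrt{\alpha_{t-1}} \left( \frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}} \right)
        + \sqrt{1 - \alpha_{t-1} - \sigma_t^2}\;\epsilon_\theta(x_t, t)
        + \sigma_t\,\epsilon_t

% DDIM eq. (16): the noise scale, with eta in [0, 1] interpolating DDIM -> DDPM
\sigma_t(\eta) = \eta \sqrt{(1-\alpha_{t-1})/(1-\alpha_t)}\,\sqrt{1 - \alpha_t/\alpha_{t-1}}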
@@ -8,14 +8,23 @@ import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
-from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
+from torchvision.transforms import (
+    Compose,
+    InterpolationMode,
+    Lambda,
+    RandomCrop,
+    RandomHorizontalFlip,
+    RandomVerticalFlip,
+    Resize,
+    ToTensor,
+)
 from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup


 def set_seed(seed):
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
+    # torch.backends.cudnn.deterministic = True
+    # torch.backends.cudnn.benchmark = False
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
@@ -30,13 +39,13 @@ model = UNetModel(
     attn_resolutions=(16,),
     ch=128,
     ch_mult=(1, 2, 2, 2),
-    dropout=0.1,
+    dropout=0.0,
     num_res_blocks=2,
     resamp_with_conv=True,
-    resolution=32
+    resolution=32,
 )
 noise_scheduler = DDPMScheduler(timesteps=1000)
-optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
+optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

 num_epochs = 100
 batch_size = 64
@@ -44,14 +53,15 @@ gradient_accumulation_steps = 2

 augmentations = Compose(
     [
-        Resize(32),
-        CenterCrop(32),
+        Resize(32, interpolation=InterpolationMode.BILINEAR),
         RandomHorizontalFlip(),
+        RandomVerticalFlip(),
+        RandomCrop(32),
         ToTensor(),
         Lambda(lambda x: x * 2 - 1),
     ]
 )
-dataset = load_dataset("huggan/pokemon", split="train")
+dataset = load_dataset("huggan/flowers-102-categories", split="train")


 def transforms(examples):
@@ -59,24 +69,24 @@ def transforms(examples):
     return {"input": images}


+dataset = dataset.shuffle(seed=0)
 dataset.set_transform(transforms)
-train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
+train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

-#lr_scheduler = get_linear_schedule_with_warmup(
-#    optimizer=optimizer,
-#    num_warmup_steps=1000,
-#    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
-#)
+lr_scheduler = get_linear_schedule_with_warmup(
+    optimizer=optimizer,
+    num_warmup_steps=500,
+    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
+)

-model, optimizer, train_dataloader = accelerator.prepare(
-    model, optimizer, train_dataloader
+model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+    model, optimizer, train_dataloader, lr_scheduler
 )

 for epoch in range(num_epochs):
     model.train()
     pbar = tqdm(total=len(train_dataloader), unit="ba")
     pbar.set_description(f"Epoch {epoch}")
+    losses = []
     for step, batch in enumerate(train_dataloader):
         clean_images = batch["input"]
         noisy_images = torch.empty_like(clean_images)
@@ -101,10 +111,12 @@ for epoch in range(num_epochs):
             accelerator.backward(loss)
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
-            # lr_scheduler.step()
+            lr_scheduler.step()
             optimizer.zero_grad()
+        loss = loss.detach().item()
+        losses.append(loss)
         pbar.update(1)
-        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
+        pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"])
     optimizer.step()
@@ -124,5 +136,5 @@ for epoch in range(num_epochs):
         image_pil = PIL.Image.fromarray(image_processed[0])

         # save image
-        pipeline.save_pretrained("./poke-ddpm")
-        image_pil.save(f"./poke-ddpm/test_{epoch}.png")
+        pipeline.save_pretrained("./flowers-ddpm")
+        image_pil.save(f"./flowers-ddpm/test_{epoch}.png")
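The training script now enables the previously commented-out scheduler with 500 warmup steps. As a minimal sketch of what transformers' get_linear_schedule_with_warmup computes (an illustrative reimplementation for reference, not the library source):

import torch

def linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    # Multiplier rises linearly 0 -> 1 over the warmup steps, then decays
    # linearly 1 -> 0 over the remaining training steps.
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

Warmup avoids large early updates while Adam's moment estimates are still noisy, which matters more now that the base learning rate was raised to 3e-4.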
@@ -225,11 +225,8 @@ class ConfigMixin:
             text = reader.read()
         return json.loads(text)

-    # def __eq__(self, other):
-    #     return self.__dict__ == other.__dict__
-
-    # def __repr__(self):
-    #     return f"{self.__class__.__name__} {self.to_json_string()}"
+    def __repr__(self):
+        return f"{self.__class__.__name__} {self.to_json_string()}"

     @property
     def config(self) -> Dict[str, Any]:
...
@@ -832,12 +832,12 @@ class GLIDE(DiffusionPipeline):

         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = self.text_noise_scheduler.sample_noise(
-            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
-        )
+        image = torch.randn(
+            (batch_size, self.text_unet.in_channels, 64, 64), generator=generator
+        ).to(torch_device)

         # 2. Encode tokens
-        # an empty input is needed to guide the model away from (
+        # an empty input is needed to guide the model away from it
         inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
         input_ids = inputs["input_ids"].to(torch_device)
         attention_mask = inputs["attention_mask"].to(torch_device)
@@ -850,7 +850,7 @@ class GLIDE(DiffusionPipeline):
             mean, variance, log_variance, pred_xstart = self.p_mean_variance(
                 text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
             )
-            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
+            noise = torch.randn(image.shape, generator=generator).to(torch_device)
             nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
             image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
@@ -873,8 +873,8 @@ class GLIDE(DiffusionPipeline):
                 self.upscale_unet.resolution,
             ),
             generator=generator,
-        )
-        image = image.to(torch_device) * upsample_temp
+        ).to(torch_device)
+        image = image * upsample_temp

         num_trained_timesteps = self.upscale_noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
@@ -896,7 +896,7 @@ class GLIDE(DiffusionPipeline):
             # 3. optionally sample variance
             variance = 0
             if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
+                noise = torch.randn(image.shape, generator=generator).to(torch_device)
                 variance = (
                     self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
                 )
...
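The "batch_size = 2" comment marks the classifier-free guidance setup: the prompt and an empty prompt go through the model in one batch, and the two predicted residuals are later combined. A minimal sketch of that combination (function and variable names are illustrative, not the pipeline's internals):

import torch

def guided_residual(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Move the prediction away from the unconditional residual and toward
    # the text-conditional one; guidance_scale > 1 amplifies the prompt.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)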
@@ -28,13 +28,11 @@ class PNDM(DiffusionPipeline):
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

     def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
-        # eta corresponds to η in paper and should be between [0, 1]
+        # For more information on the sampling method you can take a look at Algorithm 2 of
+        # the official paper: https://arxiv.org/pdf/2202.09778.pdf
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        num_trained_timesteps = self.noise_scheduler.timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
         self.unet.to(torch_device)

         # Sample gaussian noise to begin loop
@@ -44,91 +42,18 @@ class PNDM(DiffusionPipeline):
         )
         image = image.to(torch_device)

-        seq = list(inference_step_times)
-        seq_next = [-1] + list(seq[:-1])
-        model = self.unet
-
-        ets = []
-        prev_noises = []
-        step_idx = len(seq) - 1
-        while step_idx >= 0:
-            i = seq[step_idx]
-            j = seq_next[step_idx]
-            t = (torch.ones(image.shape[0]) * i)
-            t_next = (torch.ones(image.shape[0]) * j)
-            residual = model(image.to("cuda"), t.to("cuda"))
-            residual = residual.to("cpu")
-            t_list = [t, (t+t_next)/2, t_next]
-            ets.append(residual)
-            if len(ets) <= 3:
-                image = image.to("cpu")
-                x_2 = self.noise_scheduler.transfer(image.to("cpu"), t_list[0], t_list[1], residual)
-                e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu")
-                x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
-                e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu")
-                x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
-                e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu")
-                residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
-            else:
-                residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
-
-            img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
-            image = img_next
-            step_idx = step_idx - 1
-
-            # if len(prev_noises) in [1, 2]:
-            #     t = (t + t_next) / 2
-            # elif len(prev_noises) == 3:
-            #     t = t_next / 2
-            # if len(prev_noises) == 0:
-            #     ets.append(residual)
-            #
-            # if len(ets) > 3:
-            #     residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
-            #     step_idx = step_idx - 1
-            # elif len(ets) <= 3 and len(prev_noises) == 3:
-            #     residual = (1 / 6) * (prev_noises[-3] + 2 * prev_noises[-2] + 2 * prev_noises[-1] + residual)
-            #     prev_noises = []
-            #     step_idx = step_idx - 1
-            # elif len(ets) <= 3 and len(prev_noises) < 3:
-            #     prev_noises.append(residual)
-            #     if len(prev_noises) < 2:
-            #         t_next = (t + t_next) / 2
-            #
-            # image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
+        warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
+        for t in tqdm.tqdm(range(len(warmup_time_steps))):
+            t_orig = warmup_time_steps[t]
+            residual = self.unet(image, t_orig)
+            image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
+
+        timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
+        for t in tqdm.tqdm(range(len(timesteps))):
+            t_orig = timesteps[t]
+            residual = self.unet(image, t_orig)
+            image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)

         return image
-
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointing to x_t"
-        # - pred_prev_image -> "x_t-1"
-        # for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-        #     1. predict noise residual
-        #     with torch.no_grad():
-        #         residual = self.unet(image, inference_step_times[t])
-        #
-        #     2. predict previous mean of image x_t-1
-        #     pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
-        #
-        #     3. optionally sample variance
-        #     variance = 0
-        #     if eta > 0:
-        #         noise = torch.randn(image.shape, generator=generator).to(image.device)
-        #         variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
-        #
-        #     4. set current image to prev_image: x_t -> x_t-1
-        #     image = pred_prev_image + variance
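The refactor moves the numerics into the scheduler's new step_prk and step_plms methods shown below. For reference, the two update rules they implement, reconstructed from the coefficients in the code (the scheduler comments point at eqs. (12) and (13) of the PNDM paper); e_i denotes a model residual ε_θ:

% Pseudo linear multistep (step_plms): Adams-Bashforth combination of the
% last four residuals, matching the (1/24)(55, -59, 37, -9) coefficients
e'_t = \tfrac{1}{24}\bigl(55\,e_t - 59\,e_{t-\delta} + 37\,e_{t-2\delta} - 9\,e_{t-3\delta}\bigr)

% Pseudo Runge-Kutta warmup (step_prk): classic RK4 weights, accumulated
% across four model calls as 1/6, 1/3, 1/3, 1/6
e'_t = \tfrac{1}{6}\bigl(e_1 + 2e_2 + 2e_3 + e_4\bigr)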
@@ -55,22 +55,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

         self.set_format(tensor_format=tensor_format)

-        # self.register_buffer("betas", betas.to(torch.float32))
-        # self.register_buffer("alphas", alphas.to(torch.float32))
-        # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
-
-        # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
-
-        # TODO(PVP) - check how much of these is actually necessary!
-        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
-        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
-        # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-        # if variance_type == "fixed_small":
-        #     log_variance = torch.log(variance.clamp(min=1e-20))
-        # elif variance_type == "fixed_large":
-        #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
-
-        # self.register_buffer("log_variance", log_variance.to(torch.float32))
+        # For now we only support F-PNDM, i.e. the runge-kutta method
+        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+        # mainly at equations (12) and (13) and the Algorithm 2.
+        self.pndm_order = 4
+
+        # running values
+        self.cur_residual = 0
+        self.cur_image = None
+        self.ets = []
+        self.warmup_time_steps = {}
+        self.time_steps = {}

     def get_alpha(self, time_step):
         return self.alphas[time_step]
@@ -83,51 +78,64 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             return self.one
         return self.alphas_cumprod[time_step]

-    def step(self, img, t_start, t_end, model, ets):
-        # img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets)
-        # def gen_order_4(img, t, t_next, model, alphas_cump, ets):
-        t_next, t = t_start, t_end
-        noise_ = model(img.to("cuda"), t.to("cuda"))
-        noise_ = noise_.to("cpu")
-        t_list = [t, (t+t_next)/2, t_next]
-        if len(ets) > 2:
-            ets.append(noise_)
-            noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
-        else:
-            noise = self.runge_kutta(img, t_list, model, ets, noise_)
-
-        img_next = self.transfer(img.to("cpu"), t, t_next, noise)
-        return img_next, ets
-
-    def runge_kutta(self, x, t_list, model, ets, noise_):
-        model = model.to("cuda")
-        x = x.to("cpu")
-        e_1 = noise_
-        ets.append(e_1)
-        x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
-        e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
-        e_2 = e_2.to("cpu")
-        x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
-        e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
-        e_3 = e_3.to("cpu")
-        x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
-        e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
-        e_4 = e_4.to("cpu")
-        et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4)
-
-        return et
+    def get_warmup_time_steps(self, num_inference_steps):
+        if num_inference_steps in self.warmup_time_steps:
+            return self.warmup_time_steps[num_inference_steps]
+
+        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
+
+        warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order)
+        self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
+
+        return self.warmup_time_steps[num_inference_steps]
+
+    def get_time_steps(self, num_inference_steps):
+        if num_inference_steps in self.time_steps:
+            return self.time_steps[num_inference_steps]
+
+        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
+        self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
+
+        return self.time_steps[num_inference_steps]
+
+    def step_prk(self, residual, image, t, num_inference_steps):
+        # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
+        warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
+        t_prev = warmup_time_steps[t // 4 * 4]
+        t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
+
+        if t % 4 == 0:
+            self.cur_residual += 1 / 6 * residual
+            self.ets.append(residual)
+            self.cur_image = image
+        elif (t - 1) % 4 == 0:
+            self.cur_residual += 1 / 3 * residual
+        elif (t - 2) % 4 == 0:
+            self.cur_residual += 1 / 3 * residual
+        elif (t - 3) % 4 == 0:
+            residual = self.cur_residual + 1 / 6 * residual
+            self.cur_residual = 0
+
+        return self.transfer(self.cur_image, t_prev, t_next, residual)
+
+    def step_plms(self, residual, image, t, num_inference_steps):
+        timesteps = self.get_time_steps(num_inference_steps)
+        t_prev = timesteps[t]
+        t_next = timesteps[min(t + 1, len(timesteps) - 1)]
+
+        self.ets.append(residual)
+        residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+        return self.transfer(image, t_prev, t_next, residual)

     def transfer(self, x, t, t_next, et):
-        alphas_cump = self.alphas_cumprod
-        at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
-        at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
+        # TODO(Patrick): clean up to be compatible with numpy and give better names
+        alphas_cump = self.alphas_cumprod.to(x.device)
+        at = alphas_cump[t + 1].view(-1, 1, 1, 1)
+        at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)

         x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)
...
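To make the index gymnastics in get_warmup_time_steps concrete, here is a standalone sketch for the scheduler defaults (timesteps=1000, num_inference_steps=50); the printed list was checked by hand:

import numpy as np

timesteps, num_inference_steps, pndm_order = 1000, 50, 4
step_times = list(range(0, timesteps, timesteps // num_inference_steps))  # [0, 20, ..., 980]

# Take the last four step times, repeat each, and offset every second entry
# by half a step size -- the same computation as get_warmup_time_steps.
warmup = np.array(step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, timesteps // num_inference_steps // 2]), pndm_order
)
warmup = [int(v) for v in reversed(warmup[:-1].repeat(2)[1:-1])]
print(warmup)  # [980, 970, 970, 960, 960, 950, 950, 940, 940, 930, 930, 920]
# i.e. three PRK steps of four model evaluations each: (t, t - δ/2, t - δ/2, t - δ)

plms_times = list(reversed(step_times[:-3]))  # [920, 900, ..., 0] -> 47 PLMS steps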
@@ -19,7 +19,7 @@ import unittest

 import torch

-from diffusers import DDIM, DDPM, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
+from diffusers import DDIM, DDPM, PNDM, GLIDE, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.pipeline_bddm import DiffWave
@@ -229,3 +229,17 @@ class PipelineTesterMixin(unittest.TestCase):
         _ = BDDM.from_pretrained(tmpdirname)

         # check if the same works using the DifusionPipeline class
         _ = DiffusionPipeline.from_pretrained(tmpdirname)
+
+    @slow
+    def test_glide_text2img(self):
+        model_id = "fusing/glide-base"
+        glide = GLIDE.from_pretrained(model_id)
+
+        prompt = "a pencil sketch of a corgi"
+        generator = torch.manual_seed(0)
+        image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
+
+        image_slice = image[0, :3, :3, -1].cpu()
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
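The new test is marked @slow; in Hugging Face test suites such tests are usually skipped unless the RUN_SLOW environment variable is set. The exact test-file path is not shown in this diff, so the invocation below is an assumption:

# Assumed invocation (path and env-var gating are not shown in the diff):
#   RUN_SLOW=1 python -m pytest -k test_glide_text2img tests/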