Commit ca2635d9 authored by anton-l's avatar anton-l
Browse files

GlideDDIM -> DDIM

parent 5c21d962
......@@ -14,4 +14,3 @@ from .schedulers import SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.ddim import DDIMScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.glide_ddim import GlideDDIMScheduler
......@@ -420,6 +420,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
def __init__(
self,
in_channels=3,
resolution=64,
model_channels=192,
out_channels=6,
num_res_blocks=3,
......@@ -443,6 +444,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
num_heads_upsample = num_heads
self.in_channels = in_channels
self.resolution = resolution
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
......@@ -649,6 +651,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
def __init__(
self,
in_channels=3,
resolution=64,
model_channels=192,
out_channels=6,
num_res_blocks=3,
......@@ -668,6 +671,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
):
super().__init__(
in_channels=in_channels,
resolution=resolution,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
......@@ -687,6 +691,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
)
self.register(
in_channels=in_channels,
resolution=resolution,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
......@@ -739,6 +744,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
def __init__(
self,
in_channels=3,
resolution=256,
model_channels=192,
out_channels=6,
num_res_blocks=3,
......@@ -757,6 +763,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
):
super().__init__(
in_channels=in_channels,
resolution=resolution,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
......@@ -775,6 +782,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
)
self.register(
in_channels=in_channels,
resolution=resolution,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
......
......@@ -37,7 +37,6 @@ LOADABLE_CLASSES = {
"SchedulerMixin": ["save_config", "from_config"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
"GlideDDIMScheduler": ["save_config", "from_config"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
......
import torch
from torch import nn
from diffusers import (
ClassifierFreeGuidanceScheduler,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from modeling_glide import GLIDE, CLIPTextModel
from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from diffusers.pipelines.pipeline_glide import GLIDE, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer
......@@ -102,7 +97,7 @@ superres_model = GLIDESuperResUNetModel(
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02)
glide = GLIDE(
text_unet=text2im_model,
......
import torch
from torch import nn
from diffusers import (
ClassifierFreeGuidanceScheduler,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from diffusers import ClassifierFreeGuidanceScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from modeling_glide import GLIDE, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer
......@@ -102,7 +97,7 @@ superres_model = GLIDESuperResUNetModel(
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(
text_unet=text2im_model,
......
......@@ -26,8 +26,8 @@ from torch import nn
import tqdm
from diffusers import (
ClassifierFreeGuidanceScheduler,
DDIMScheduler,
DiffusionPipeline,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
......@@ -727,7 +727,7 @@ class GLIDE(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel,
upscale_noise_scheduler: GlideDDIMScheduler,
upscale_noise_scheduler: DDIMScheduler,
):
super().__init__()
self.register_modules(
......
......@@ -38,7 +38,7 @@ from transformers.utils import (
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
#####################
......@@ -724,7 +724,7 @@ class GLIDE(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel,
upscale_noise_scheduler: GlideDDIMScheduler,
upscale_noise_scheduler: DDIMScheduler,
):
super().__init__()
self.register_modules(
......@@ -816,7 +816,7 @@ class GLIDE(DiffusionPipeline):
) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
@torch.no_grad()
def __call__(self, prompt, generator=None, torch_device=None):
def __call__(self, prompt, generator=None, torch_device=None, num_inference_steps_upscale=50):
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.text_unet.to(torch_device)
......@@ -870,50 +870,45 @@ class GLIDE(DiffusionPipeline):
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997
image = (
self.upscale_noise_scheduler.sample_noise(
(batch_size, 3, 256, 256), device=torch_device, generator=generator
)
* upsample_temp
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
generator=generator,
)
num_timesteps = len(self.upscale_noise_scheduler)
for t in tqdm.tqdm(
reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (
(1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
image = image.to(torch_device)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
# 1. predict noise residual
with torch.no_grad():
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
model_output = self.upscale_unet(image, time_input, low_res)
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
# 2. predict previous mean of image x_t-1
pred_prev_image = self.upscale_noise_scheduler.step(
noise_residual, image, t, num_inference_steps_upscale, eta
)
clipped_coeff = (
torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* self.upscale_noise_scheduler.get_beta(t)
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
# ii) predict noise residual
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
model_output = self.upscale_unet(image, time_input, low_res)
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
# iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = self.upscale_noise_scheduler.sample_variance(
t, prev_image.shape, device=torch_device, generator=generator
)
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = (
self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
image = sampled_prev_image
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
image = image.permute(0, 2, 3, 1)
......
......@@ -19,5 +19,4 @@
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .ddim import DDIMScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
from .glide_ddim import GlideDDIMScheduler
from .schedulers_utils import SchedulerMixin
......@@ -69,6 +69,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
def rescale_betas(self, num_timesteps):
    """Recompute the beta schedule for a run of `num_timesteps` diffusion steps.

    The linear schedule's start/end betas are scaled by the ratio of the
    configured step count to the requested one, so the total noise added over
    the shortened trajectory matches the full-length schedule. Non-linear
    schedules are deliberately left untouched (same as before).
    """
    if self.beta_schedule != "linear":
        return
    ratio = self.timesteps / num_timesteps
    new_betas = linear_beta_schedule(
        num_timesteps,
        beta_start=self.beta_start * ratio,
        beta_end=self.beta_end * ratio,
    )
    self.betas = new_betas
    self.alphas = 1.0 - new_betas
    self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
def get_alpha(self, time_step):
    """Return alpha_t (= 1 - beta_t) for integer timestep `time_step`."""
    return self.alphas[time_step]
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import nn
from ..configuration_utils import ConfigMixin
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json"
class GlideDDIMScheduler(nn.Module, ConfigMixin):
    """Fixed linear-beta DDIM noise scheduler used by the GLIDE upscaler.

    Precomputes the diffusion constants (betas, alphas, cumulative alpha
    products, log-variances) as non-trainable buffers so they move with the
    module across devices.
    """

    config_name = SAMPLING_CONFIG_NAME

    def __init__(self, timesteps=1000, beta_schedule="linear", variance_type="fixed_large"):
        """
        Args:
            timesteps: total number of diffusion steps in the schedule.
            beta_schedule: noise schedule name; only "linear" is implemented.
            variance_type: "fixed_small" (posterior variance) or
                "fixed_large" (beta_t) log-variance scheme.

        Raises:
            NotImplementedError: for an unsupported `beta_schedule` or
                `variance_type`.
        """
        super().__init__()
        # Register variance_type as well, so save_config/from_config
        # round-trips the full __init__ signature (it was dropped before).
        self.register(
            timesteps=timesteps,
            beta_schedule=beta_schedule,
            variance_type=variance_type,
        )
        self.timesteps = int(timesteps)

        if beta_schedule == "linear":
            # Linear schedule from Ho et al., rescaled so any number of
            # diffusion steps spans the same total noise range as the
            # canonical 1000-step schedule.
            scale = 1000 / self.timesteps
            beta_start = scale * 0.0001
            beta_end = scale * 0.02
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar at t=-1 defined as 1.
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        if variance_type == "fixed_small":
            # Clamp before log to avoid -inf at t=0 where the posterior variance is 0.
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            # Use beta_t (the larger choice), keeping the t=0 entry from the
            # posterior variance to avoid log(0).
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
        else:
            # Previously an unknown variance_type fell through both branches
            # and crashed later with an unbound `log_variance`; fail fast.
            raise NotImplementedError(f"{variance_type} is not implemented for {self.__class__}")

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        """Return alpha_t (= 1 - beta_t)."""
        return self.alphas[time_step]

    def get_beta(self, time_step):
        """Return beta_t."""
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        """Return alpha_bar_t; alpha_bar at t < 0 is defined as 1."""
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        """Sample sigma_t * noise for step t; identically zero at t == 0
        (no noise is added on the final denoising step)."""
        variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :]
        noise = self.sample_noise(shape, device=device, generator=generator)
        sampled_variance = nonzero_mask * (0.5 * variance).exp()
        sampled_variance = sampled_variance * noise
        return sampled_variance

    def sample_noise(self, shape, device, generator=None):
        """Draw standard gaussian noise on CPU (for determinism with a fixed
        generator), then move it to `device`."""
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.timesteps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment