Commit ca2635d9 authored by anton-l

GlideDDIM -> DDIM

parent 5c21d962
@@ -14,4 +14,3 @@ from .schedulers import SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .schedulers.ddim import DDIMScheduler
 from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
-from .schedulers.glide_ddim import GlideDDIMScheduler
...
@@ -420,6 +420,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
     def __init__(
         self,
         in_channels=3,
+        resolution=64,
         model_channels=192,
         out_channels=6,
         num_res_blocks=3,
@@ -443,6 +444,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
         num_heads_upsample = num_heads
         self.in_channels = in_channels
+        self.resolution = resolution
         self.model_channels = model_channels
         self.out_channels = out_channels
         self.num_res_blocks = num_res_blocks
@@ -649,6 +651,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
     def __init__(
         self,
         in_channels=3,
+        resolution=64,
         model_channels=192,
         out_channels=6,
         num_res_blocks=3,
@@ -668,6 +671,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
     ):
         super().__init__(
             in_channels=in_channels,
+            resolution=resolution,
             model_channels=model_channels,
             out_channels=out_channels,
             num_res_blocks=num_res_blocks,
@@ -687,6 +691,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
         )
         self.register(
             in_channels=in_channels,
+            resolution=resolution,
             model_channels=model_channels,
             out_channels=out_channels,
             num_res_blocks=num_res_blocks,
@@ -739,6 +744,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
     def __init__(
         self,
         in_channels=3,
+        resolution=256,
         model_channels=192,
         out_channels=6,
         num_res_blocks=3,
@@ -757,6 +763,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
     ):
         super().__init__(
             in_channels=in_channels,
+            resolution=resolution,
             model_channels=model_channels,
             out_channels=out_channels,
             num_res_blocks=num_res_blocks,
@@ -775,6 +782,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
         )
         self.register(
             in_channels=in_channels,
+            resolution=resolution,
             model_channels=model_channels,
             out_channels=out_channels,
             num_res_blocks=num_res_blocks,
...
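Threading `resolution` through `__init__` and `self.register(...)` lets downstream code derive tensor shapes from the model config instead of hard-coding them, which is what the pipeline change further down relies on. A minimal sketch of the pattern (the helper function and its names are illustrative, not part of the library):

```python
import torch

def initial_noise(unet, batch_size=1, generator=None):
    # The shape comes from the registered config rather than hard-coded
    # constants: (B, 3, 64, 64) for the text-to-image UNet by default,
    # (B, 3, 256, 256) for the super-resolution UNet.
    return torch.randn(
        (batch_size, unet.in_channels, unet.resolution, unet.resolution),
        generator=generator,
    )
```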
@@ -37,7 +37,6 @@ LOADABLE_CLASSES = {
         "SchedulerMixin": ["save_config", "from_config"],
         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
         "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
-        "GlideDDIMScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
...
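`LOADABLE_CLASSES` maps each library to the base classes a pipeline component may inherit from, paired with the method names used to save and load it; the `GlideDDIMScheduler` entry can go because `DDIMScheduler` is already covered by the generic `SchedulerMixin` entry. A rough sketch of how such a registry is typically consulted on the save path (hedged: the real `pipeline_utils.py` logic differs in detail):

```python
def save_component(component, save_directory, loadable_classes):
    # Walk the registry and use the first base class the component
    # inherits from; its registered save method is then invoked.
    for library, classes in loadable_classes.items():
        module = __import__(library)
        for class_name, (save_method, _load_method) in classes.items():
            base = getattr(module, class_name, None)
            if base is not None and isinstance(component, base):
                getattr(component, save_method)(save_directory)
                return
    raise ValueError(f"{component.__class__.__name__} is not loadable")
```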
 import torch
 from torch import nn

-from diffusers import (
-    ClassifierFreeGuidanceScheduler,
-    GlideDDIMScheduler,
-    GLIDESuperResUNetModel,
-    GLIDETextToImageUNetModel,
-)
-from modeling_glide import GLIDE, CLIPTextModel
+from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from diffusers.pipelines.pipeline_glide import GLIDE, CLIPTextModel
 from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -102,7 +97,7 @@ superres_model = GLIDESuperResUNetModel(
 superres_model.load_state_dict(ups_state_dict, strict=False)

-upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
+upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02)

 glide = GLIDE(
     text_unet=text2im_model,
...
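Spelling out `beta_start=0.0001, beta_end=0.02` reproduces what the deleted `GlideDDIMScheduler` computed internally for 1000 steps, since its scale factor is `1000 / 1000 = 1`. A quick check of that arithmetic:

```python
timesteps = 1000

# Scaled-linear schedule constants from the removed GlideDDIMScheduler:
scale = 1000 / timesteps          # = 1.0 for the default 1000 steps
beta_start = scale * 0.0001       # = 0.0001
beta_end = scale * 0.02           # = 0.02

assert (beta_start, beta_end) == (0.0001, 0.02)
```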
 import torch
 from torch import nn

-from diffusers import (
-    ClassifierFreeGuidanceScheduler,
-    GlideDDIMScheduler,
-    GLIDESuperResUNetModel,
-    GLIDETextToImageUNetModel,
-)
+from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from modeling_glide import GLIDE, CLIPTextModel
 from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -102,7 +97,7 @@ superres_model = GLIDESuperResUNetModel(
 superres_model.load_state_dict(ups_state_dict, strict=False)

-upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
+upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear")

 glide = GLIDE(
     text_unet=text2im_model,
...
@@ -26,8 +26,8 @@ from torch import nn
 import tqdm
 from diffusers import (
     ClassifierFreeGuidanceScheduler,
+    DDIMScheduler,
     DiffusionPipeline,
-    GlideDDIMScheduler,
     GLIDESuperResUNetModel,
     GLIDETextToImageUNetModel,
 )
@@ -727,7 +727,7 @@ class GLIDE(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
-        upscale_noise_scheduler: GlideDDIMScheduler,
+        upscale_noise_scheduler: DDIMScheduler,
     ):
         super().__init__()
         self.register_modules(
...
@@ -38,7 +38,7 @@ from transformers.utils import (
 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
-from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
+from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler

 #####################
@@ -724,7 +724,7 @@ class GLIDE(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
-        upscale_noise_scheduler: GlideDDIMScheduler,
+        upscale_noise_scheduler: DDIMScheduler,
     ):
         super().__init__()
         self.register_modules(
@@ -816,7 +816,7 @@ class GLIDE(DiffusionPipeline):
         ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

     @torch.no_grad()
-    def __call__(self, prompt, generator=None, torch_device=None):
+    def __call__(self, prompt, generator=None, torch_device=None, num_inference_steps_upscale=50):
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"

         self.text_unet.to(torch_device)
@@ -870,50 +870,45 @@ class GLIDE(DiffusionPipeline):
         # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
         upsample_temp = 0.997

-        image = (
-            self.upscale_noise_scheduler.sample_noise(
-                (batch_size, 3, 256, 256), device=torch_device, generator=generator
-            )
-            * upsample_temp
-        )
-
-        num_timesteps = len(self.upscale_noise_scheduler)
-        for t in tqdm.tqdm(
-            reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
-        ):
-            # i) define coefficients for time step t
-            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
-            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
-            image_coeff = (
-                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
-                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
-                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
-            )
-            clipped_coeff = (
-                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
-                * self.upscale_noise_scheduler.get_beta(t)
-                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
-            )
-
-            # ii) predict noise residual
-            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
-            model_output = self.upscale_unet(image, time_input, low_res)
-            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
-
-            # iii) compute predicted image from residual
-            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
-            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
-            pred_mean = torch.clamp(pred_mean, -1, 1)
-            prev_image = clipped_coeff * pred_mean + image_coeff * image
-
-            # iv) sample variance
-            prev_variance = self.upscale_noise_scheduler.sample_variance(
-                t, prev_image.shape, device=torch_device, generator=generator
-            )
-
-            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
-            sampled_prev_image = prev_image + prev_variance
-            image = sampled_prev_image
+        # Sample gaussian noise to begin loop
+        image = torch.randn(
+            (batch_size, self.upscale_unet.in_channels, self.upscale_unet.resolution, self.upscale_unet.resolution),
+            generator=generator,
+        )
+        image = image.to(torch_device) * upsample_temp
+
+        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper in detail for a full understanding
+
+        # Notation (<variable name> -> <name in paper>)
+        # - pred_noise_t -> e_theta(x_t, t)
+        # - pred_original_image -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_image_direction -> "direction pointing to x_t"
+        # - pred_prev_image -> "x_t-1"
+        for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
+            # 1. predict noise residual
+            with torch.no_grad():
+                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                model_output = self.upscale_unet(image, time_input, low_res)
+                noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
+
+            # 2. predict previous mean of image x_t-1
+            pred_prev_image = self.upscale_noise_scheduler.step(
+                noise_residual, image, t, num_inference_steps_upscale, eta
+            )
+
+            # 3. optionally sample variance
+            variance = 0
+            if eta > 0:
+                noise = torch.randn(image.shape, generator=generator).to(image.device)
+                variance = (
+                    self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
+                )

+            # 4. set current image to prev_image: x_t -> x_t-1
+            image = pred_prev_image + variance

         image = image.permute(0, 2, 3, 1)
...
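For context, the `step(...)` call introduced above corresponds to formulas (12) and (16) of the DDIM paper: the predicted noise is first converted into a predicted `x_0`, then recombined with a "direction pointing to x_t" term, with `eta` interpolating between deterministic DDIM (`eta = 0`) and DDPM-like sampling (`eta = 1`). A hedged sketch of that update with cumulative alphas passed in directly (illustrative signature, not the scheduler's actual API):

```python
import torch

def ddim_step(noise_residual, image, alpha_prod_t, alpha_prod_t_prev, eta):
    # pred_original_image: f_theta(x_t, t), the model's guess of x_0
    pred_original_image = (image - (1 - alpha_prod_t).sqrt() * noise_residual) / alpha_prod_t.sqrt()
    # std_dev_t (sigma_t) from formula (16); eta = 0 gives deterministic DDIM
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * variance.sqrt()
    # "direction pointing to x_t" plus the mean of formula (12); the caller
    # adds the sigma_t * noise term separately, as step 3 of the loop does
    pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * noise_residual
    return alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction

# Toy check with dummy tensors: eta=0 makes the update fully deterministic.
x_prev = ddim_step(torch.zeros(1, 3, 8, 8), torch.zeros(1, 3, 8, 8),
                   torch.tensor(0.9), torch.tensor(0.95), eta=0.0)
```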
@@ -19,5 +19,4 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .ddim import DDIMScheduler
 from .gaussian_ddpm import GaussianDDPMScheduler
-from .glide_ddim import GlideDDIMScheduler
 from .schedulers_utils import SchedulerMixin
...
@@ -69,6 +69,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))

+    def rescale_betas(self, num_timesteps):
+        if self.beta_schedule == "linear":
+            scale = self.timesteps / num_timesteps
+            self.betas = linear_beta_schedule(
+                num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
+            )
+            self.alphas = 1.0 - self.betas
+            self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
     def get_alpha(self, time_step):
         return self.alphas[time_step]
...
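`rescale_betas` rebuilds the linear schedule for a shorter trajectory by scaling both endpoints by `self.timesteps / num_timesteps`, the same "extended to work for any number of diffusion steps" trick used by the deleted scheduler below. Compressing 1000 trained steps into 50 inference steps, for instance, scales the endpoints by 20. A hedged usage sketch, assuming the constructor arguments shown in the conversion script above:

```python
from diffusers import DDIMScheduler

# Constructor arguments as used in the conversion script above.
scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02)

# 1000 trained steps -> 50 inference steps: scale = 1000 / 50 = 20, so the
# rebuilt schedule runs linearly from 20 * 0.0001 = 0.002 to 20 * 0.02 = 0.4.
scheduler.rescale_betas(50)
```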
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import nn

from ..configuration_utils import ConfigMixin
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule

SAMPLING_CONFIG_NAME = "scheduler_config.json"


class GlideDDIMScheduler(nn.Module, ConfigMixin):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(self, timesteps=1000, beta_schedule="linear", variance_type="fixed_large"):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_schedule=beta_schedule,
        )
        self.timesteps = int(timesteps)

        if beta_schedule == "linear":
            # Linear schedule from Ho et al, extended to work for any number of
            # diffusion steps.
            scale = 1000 / self.timesteps
            beta_start = scale * 0.0001
            beta_end = scale * 0.02
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        if variance_type == "fixed_small":
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        return self.alphas[time_step]

    def get_beta(self, time_step):
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :]
        noise = self.sample_noise(shape, device=device, generator=generator)
        sampled_variance = nonzero_mask * (0.5 * variance).exp()
        sampled_variance = sampled_variance * noise
        return sampled_variance

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.timesteps
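For reference, `sample_variance` above leans on the identity `exp(0.5 * log_var) = sqrt(var)` and on a mask that suppresses noise at `t == 0` so the final denoising step is deterministic. A small self-contained sketch of the same idea (illustrative names, not the scheduler's API):

```python
import torch

def sample_variance_term(log_variance, t, shape, generator=None):
    # exp(0.5 * log(var)) == sqrt(var), i.e. the standard deviation; the
    # mask zeroes the noise at t == 0 so the last step returns the mean.
    nonzero_mask = 0.0 if t == 0 else 1.0
    noise = torch.randn(shape, generator=generator)  # CPU noise, matching sample_noise
    return nonzero_mask * (0.5 * log_variance[t]).exp() * noise
```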