Commit dc6324d4 authored by anton-l

end-to-end glide pipeline with DDIM scheduler for upscaling

parent ff89f808
import torch
from torch import nn
-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -76,7 +76,7 @@ text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="
### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
-state_dict = torch.load("upsample.pt", map_location="cpu")
+ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GLIDESuperResUNetModel(
    in_channels=6,
@@ -93,12 +93,12 @@ superres_model = GLIDESuperResUNetModel(
    resblock_updown=True,
)
-superres_model.load_state_dict(state_dict)
-upscale_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
+superres_model.load_state_dict(ups_state_dict, strict=False)
+upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
-              upscale_unet=superres_model, upscale_noise_scheduler=scheduler)
+              upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)
glide.save_pretrained("./glide-base")
...
@@ -18,7 +18,7 @@ import numpy as np
import torch
import tqdm
-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from transformers import GPT2Tokenizer
@@ -41,17 +41,20 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class GLIDE(DiffusionPipeline):
    def __init__(
        self,
-        unet: GLIDETextToImageUNetModel,
-        noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_unet: GLIDETextToImageUNetModel,
+        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
+        upscale_unet: GLIDESuperResUNetModel,
+        upscale_noise_scheduler: GlideDDIMScheduler
    ):
        super().__init__()
        self.register_modules(
-            unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
+            text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
+            upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
        )

-    def q_posterior_mean_variance(self, x_start, x_t, t):
+    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
@@ -60,12 +63,12 @@ class GLIDE(DiffusionPipeline):
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
-            _extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
+            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
        )
-        posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
+        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
-            self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
+            scheduler.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
@@ -75,7 +78,7 @@ class GLIDE(DiffusionPipeline):
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
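For reference, the `posterior_mean_coef1`/`posterior_mean_coef2` and `posterior_variance` buffers consumed above are the standard closed-form DDPM posterior terms for q(x_{t-1} | x_t, x_0). A minimal sketch of how a scheduler would precompute them, assuming the usual Ho et al. (2020) definitions (the buffer names are copied from the calls above; the schedule itself is purely illustrative):

import numpy as np

betas = np.linspace(0.0001, 0.02, 1000)  # illustrative linear schedule
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

# q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance)
posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# posterior_variance[0] == 0, which is why a separate "clipped" log-variance buffer exists.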
-    def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
+    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.
@@ -93,51 +96,60 @@ class GLIDE(DiffusionPipeline):
        - 'log_variance': the log of 'variance'.
        - 'pred_xstart': the prediction for x_0.
        """
-        if model_kwargs is None:
-            model_kwargs = {}
        B, C = x.shape[:2]
        assert t.shape == (B,)
-        model_output = model(x, t, transformer_out)
+        if transformer_out is None:
+            # super-res model
+            model_output = model(x, t, low_res)
+        else:
+            # text2image model
+            model_output = model(x, t, transformer_out)
        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
+        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
+        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
        # The model_var_values is [-1, 1] for [min_var, max_var].
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = torch.exp(model_log_variance)
-        pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return model_mean, model_variance, model_log_variance, pred_xstart

-    def _predict_xstart_from_eps(self, x_t, t, eps):
+    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
-            _extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

+    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
+        return (
+            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
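These two helpers invert the forward-process identity x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, so each recovers one unknown from the other. A self-contained round-trip check (all names are local to this sketch):

import torch

alphas_cumprod = torch.cumprod(1 - torch.linspace(1e-4, 0.02, 1000), dim=0)
t = 500
x0, eps = torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8)
x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * eps

sqrt_recip = (1 / alphas_cumprod[t]).sqrt()         # sqrt_recip_alphas_cumprod[t]
sqrt_recipm1 = (1 / alphas_cumprod[t] - 1).sqrt()   # sqrt_recipm1_alphas_cumprod[t]
x0_hat = sqrt_recip * x_t - sqrt_recipm1 * eps             # _predict_xstart_from_eps
eps_hat = (sqrt_recip * x_t - x0_hat) / sqrt_recipm1       # _predict_eps_from_xstart
assert torch.allclose(x0_hat, x0, atol=1e-4)
assert torch.allclose(eps_hat, eps, atol=1e-4)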
    @torch.no_grad()
    def __call__(self, prompt, generator=None, torch_device=None):
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        self.unet.to(torch_device)
+        self.text_unet.to(torch_device)
        self.text_encoder.to(torch_device)
+        self.upscale_unet.to(torch_device)

        # Create a classifier-free guidance sampling function
        guidance_scale = 3.0

-        def model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, ts, transformer_out, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
-            model_out = self.unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
@@ -146,8 +158,8 @@ class GLIDE(DiffusionPipeline):
        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
-        image = self.noise_scheduler.sample_noise(
-            (batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
+        image = self.text_noise_scheduler.sample_noise(
+            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        # 2. Encode tokens
@@ -157,14 +169,60 @@ class GLIDE(DiffusionPipeline):
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

-        num_timesteps = len(self.noise_scheduler)
+        # 3. Run the text2image generation step
+        num_timesteps = len(self.text_noise_scheduler)
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
-            mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out)
-            noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
+            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
+                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
+            )
+            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
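The guidance function above implements the classifier-free combination eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond): the batch of two holds the conditional and unconditional samples, one forward pass produces both predictions, and the extrapolated eps is reused for both halves. The lines completing `text_model_fn` are elided in this hunk; the sketch below is illustrative only, not the actual body:

import torch

model_out = torch.randn(2, 6, 64, 64)   # [cond, uncond] x (3 eps channels + 3 variance channels)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, 1, dim=0)
guidance_scale = 3.0
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
# both batch entries then receive the same guided prediction
guided = torch.cat([torch.cat([half_eps, half_eps], dim=0), rest], dim=1)  # back to (2, 6, 64, 64)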
+        # 4. Run the upscaling step
+        batch_size = 1
+        image = image[:1]
+        low_res = ((image + 1) * 127.5).round() / 127.5 - 1
+
+        eta = 0.0
+
+        # Tune this parameter to control the sharpness of 256x256 images.
+        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+        upsample_temp = 0.997
+
+        image = self.upscale_noise_scheduler.sample_noise(
+            (batch_size, 3, 256, 256), device=torch_device, generator=generator
+        ) * upsample_temp
+
+        num_timesteps = len(self.upscale_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
+            # i) define coefficients for time step t
+            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
+            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (
+                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+            clipped_coeff = (
+                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * self.upscale_noise_scheduler.get_beta(t)
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+
+            # ii) predict noise residual
+            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+            model_output = self.upscale_unet(image, time_input, low_res)
+            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
+
+            # iii) compute predicted image from residual
+            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
+            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clipped_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = self.upscale_noise_scheduler.sample_variance(
+                t, prev_image.shape, device=torch_device, generator=generator
+            )
+
+            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image
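Steps i)-iii) spell out the same clipped-posterior computation as `p_mean_variance`, written against the scheduler accessors. With alpha_bar_t = get_alpha_prod(t), alpha_t = get_alpha(t) and beta_t = get_beta(t), the coefficients correspond to:

\hat{x}_0 = \frac{x_t}{\sqrt{\bar\alpha_t}} - \sqrt{\frac{1}{\bar\alpha_t} - 1}\,\hat\epsilon_\theta(x_t, t)
\quad \text{(clipped\_image\_coeff, clipped\_noise\_coeff)}

\mu_{t-1} = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}\,\mathrm{clip}(\hat{x}_0, -1, 1) + \frac{(1 - \bar\alpha_{t-1})\,\sqrt{\alpha_t}}{1 - \bar\alpha_t}\, x_t
\quad \text{(clipped\_coeff, image\_coeff)}

The first identity is exactly `_predict_xstart_from_eps`, and the second is the posterior mean of q(x_{t-1} | x_t, x_0_hat) whose coefficients appear above as `posterior_mean_coef1`/`posterior_mean_coef2`.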
        image = image[0].permute(1, 2, 0)

        return image
@@ -9,7 +9,6 @@ matplotlib.rcParams['interactive'] = True
generator = torch.Generator()
generator = generator.manual_seed(0)

-# 1. Load models
pipeline = GLIDE.from_pretrained("fusing/glide-base")

img = pipeline("a pencil sketch of a corgi", generator)
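Since the pipeline returns an H x W x C tensor, a minimal way to display it with the matplotlib setup referenced above (the [-1, 1] output range is an assumption about the sampler, not something this diff states):

import matplotlib.pyplot as plt

plt.imshow(((img + 1) / 2).clamp(0, 1).cpu().numpy())  # map assumed [-1, 1] range to [0, 1]
plt.axis("off")
plt.show()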
...
@@ -13,3 +13,4 @@ from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
+from .schedulers.glide_ddim import GlideDDIMScheduler
@@ -419,11 +419,11 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
@@ -438,24 +438,6 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
        transformer_dim=None,
    ):
        super().__init__()
-        self.register(
-            in_channels=in_channels,
-            model_channels=model_channels,
-            out_channels=out_channels,
-            num_res_blocks=num_res_blocks,
-            attention_resolutions=attention_resolutions,
-            dropout=dropout,
-            channel_mult=channel_mult,
-            conv_resample=conv_resample,
-            dims=dims,
-            use_checkpoint=use_checkpoint,
-            use_fp16=use_fp16,
-            num_heads=num_heads,
-            num_head_channels=num_head_channels,
-            num_heads_upsample=num_heads_upsample,
-            use_scale_shift_norm=use_scale_shift_norm,
-            resblock_updown=resblock_updown,
-        )
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
@@ -632,7 +614,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

-    def forward(self, x, timesteps, y=None):
+    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
@@ -641,17 +623,10 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
-        assert (y is not None) == (
-            self.num_classes is not None
-        ), "must specify y if and only if the model is class-conditional"
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-        if self.num_classes is not None:
-            assert y.shape == (x.shape[0],)
-            emb = emb + self.label_emb(y)
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
@@ -671,10 +646,66 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        transformer_dim=512
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            transformer_dim=transformer_dim
+        )
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            transformer_dim=transformer_dim
+        )
-        self.transformer_proj = nn.Linear(kwargs["transformer_dim"], self.model_channels * 4)
+        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)

    def forward(self, x, timesteps, transformer_out=None):
        hs = []
@@ -705,11 +736,77 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+        )
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+        )

-    def forward(self, x, timesteps, low_res=None, **kwargs):
+    def forward(self, x, timesteps, low_res=None):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = torch.cat([x, upsampled], dim=1)
-        return super().forward(x, timesteps, **kwargs)
\ No newline at end of file
+
+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        h = x
+        for module in self.input_blocks:
+            h = module(h, emb)
+            hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+        return self.out(h)
\ No newline at end of file
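The channel concatenation in `forward` is what dictates the `in_channels=6` used in the conversion script: 3 channels of the noisy 256x256 sample plus 3 channels of the bilinearly upsampled 64x64 base output. A quick standalone shape check (illustrative tensors only):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 256, 256)        # noisy high-res sample
low_res = torch.randn(1, 3, 64, 64)    # base model output
upsampled = F.interpolate(low_res, (256, 256), mode="bilinear")
assert torch.cat([x, upsampled], dim=1).shape == (1, 6, 256, 256)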
@@ -39,6 +39,7 @@ LOADABLE_CLASSES = {
        "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
        "GaussianDDPMScheduler": ["save_config", "from_config"],
        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
+        "GlideDDIMScheduler": ["save_config", "from_config"],
    },
    "transformers": {
        "GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
...
@@ -18,3 +18,4 @@
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
+from .glide_ddim import GlideDDIMScheduler
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
-import math
+import numpy as np
from torch import nn
from ..configuration_utils import ConfigMixin
@@ -22,36 +22,30 @@ from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
SAMPLING_CONFIG_NAME = "scheduler_config.json"

-class GaussianDDPMScheduler(nn.Module, ConfigMixin):
+class GlideDDIMScheduler(nn.Module, ConfigMixin):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
-        beta_start=0.0001,
-        beta_end=0.02,
        beta_schedule="linear",
-        variance_type="fixed_small",
+        variance_type="fixed_large"
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
-            beta_start=beta_start,
-            beta_end=beta_end,
            beta_schedule=beta_schedule,
-            variance_type=variance_type,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
+            # Linear schedule from Ho et al, extended to work for any number of
+            # diffusion steps.
+            scale = 1000 / self.num_timesteps
+            beta_start = scale * 0.0001
+            beta_end = scale * 0.02
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
-        elif beta_schedule == "squaredcos_cap_v2":
-            # GLIDE cosine schedule
-            betas = betas_for_alpha_bar(
-                timesteps,
-                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
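The pipeline's upscaling loop calls `get_alpha`, `get_beta`, `get_alpha_prod` and `sample_variance` on this scheduler, none of which appear in this hunk. A hypothetical sketch of the accessors, assuming `linear_beta_schedule` returns the usual `torch.linspace(beta_start, beta_end, timesteps)` and that the cumulative product at t == -1 is defined as 1.0 (both are assumptions, not the actual definitions):

import torch

class AccessorSketch:
    def __init__(self, betas):
        self.betas = betas
        self.alphas_cumprod = torch.cumprod(1 - betas, dim=0)

    def get_beta(self, t):
        return self.betas[t]

    def get_alpha(self, t):
        return 1 - self.betas[t]

    def get_alpha_prod(self, t):
        # assumed boundary convention for the t - 1 == -1 lookup in the sampling loop
        return torch.tensor(1.0) if t < 0 else self.alphas_cumprod[t]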
...