"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f2df39fa0e6246d13aea03364366b2d53a4ab5f9"
Commit 74d2da99 authored by patil-suraj

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents 397b31c8 c6a33e3d
import torch
from torch import nn

from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer

@@ -51,9 +51,9 @@ for layer_idx in range(config.num_hidden_layers):
    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]

### Convert the Text-to-Image UNet
text2im_model = GLIDETextToImageUNetModel(
    in_channels=3,
    model_channels=192,
    out_channels=6,
@@ -69,10 +69,38 @@ text2im_model = GLIDETextToImageUNetModel(
    transformer_dim=512,
)
text2im_model.load_state_dict(state_dict, strict=False)

text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GLIDESuperResUNetModel(
    in_channels=6,
    model_channels=192,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(8, 16, 32),
    dropout=0.1,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
              upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)
glide.save_pretrained("./glide-base")
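
# The saved folder can be reloaded for inference; a minimal sketch, assuming
# DiffusionPipeline.from_pretrained resolves a local path the same way it
# resolves a hub id (as the sample script below does with "fusing/glide-base"):
#
#   from diffusers import DiffusionPipeline
#   pipeline = DiffusionPipeline.from_pretrained("./glide-base")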
@@ -18,7 +18,7 @@ import numpy as np
import torch
import tqdm

from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from transformers import GPT2Tokenizer

@@ -41,17 +41,20 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class GLIDE(DiffusionPipeline):
    def __init__(
        self,
        text_unet: GLIDETextToImageUNetModel,
        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
        upscale_unet: GLIDESuperResUNetModel,
        upscale_noise_scheduler: GlideDDIMScheduler,
    ):
        super().__init__()
        self.register_modules(
            text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
            upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
        )

    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
@@ -60,12 +63,12 @@ class GLIDE(DiffusionPipeline):
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            scheduler.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
@@ -75,7 +78,7 @@ class GLIDE(DiffusionPipeline):
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
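
    # For reference: the scheduler coefficients read above are the standard
    # DDPM posterior terms (Ho et al., 2020). A sketch, assuming `betas`,
    # `alphas`, `alphas_cumprod` and `alphas_cumprod_prev` are precomputed
    # arrays on the scheduler:
    #   posterior_mean_coef1 = betas * sqrt(alphas_cumprod_prev) / (1 - alphas_cumprod)
    #   posterior_mean_coef2 = (1 - alphas_cumprod_prev) * sqrt(alphas) / (1 - alphas_cumprod)
    #   posterior_variance   = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)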
    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.
@@ -93,51 +96,60 @@ class GLIDE(DiffusionPipeline):
        - 'log_variance': the log of 'variance'.
        - 'pred_xstart': the prediction for x_0.
        """
        B, C = x.shape[:2]
        assert t.shape == (B,)
        if transformer_out is None:
            # super-res model
            model_output = model(x, t, low_res)
        else:
            # text2image model
            model_output = model(x, t, transformer_out)
        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = torch.split(model_output, C, dim=1)
        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
        # The model_var_values is [-1, 1] for [min_var, max_var].
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = torch.exp(model_log_variance)
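        # The learned-variance interpolation follows Improved DDPM (Nichol &
        # Dhariwal, 2021): the network predicts a per-pixel fraction that
        # interpolates, in log space, between beta_t (the largest sensible
        # variance) and the clipped posterior variance (the smallest).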
        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)
        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return model_mean, model_variance, model_log_variance, pred_xstart

    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
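
    # Both helpers invert the forward process x_t = sqrt(alphas_cumprod_t) * x_0
    # + sqrt(1 - alphas_cumprod_t) * eps; solving for x_0 (or eps) yields exactly
    # the sqrt_recip_alphas_cumprod / sqrt_recipm1_alphas_cumprod terms used above.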
    @torch.no_grad()
    def __call__(self, prompt, generator=None, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.text_unet.to(torch_device)
        self.text_encoder.to(torch_device)
        self.upscale_unet.to(torch_device)

        # Create a classifier-free guidance sampling function
        guidance_scale = 3.0
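
        # Classifier-free guidance combines the conditional and unconditional
        # noise predictions:
        #   eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
        # Values above 1 trade diversity for prompt fidelity; 3.0 is the value
        # used in OpenAI's reference GLIDE sampling code for the base model.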
        def text_model_fn(x_t, ts, transformer_out, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
@@ -146,8 +158,8 @@ class GLIDE(DiffusionPipeline):
        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
        image = self.text_noise_scheduler.sample_noise(
            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        # 2. Encode tokens
@@ -157,14 +169,60 @@ class GLIDE(DiffusionPipeline):
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

        # 3. Run the text2image generation step
        num_timesteps = len(self.text_noise_scheduler)
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
            )
            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        # 4. Run the upscaling step
        batch_size = 1
        image = image[:1]
        low_res = ((image + 1) * 127.5).round() / 127.5 - 1
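        # The round-trip through the 8-bit range mimics the uint8 quantization
        # that real low-resolution conditioning images go through, which is
        # presumably what the super-res model saw during training.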
        eta = 0.0

        # Tune this parameter to control the sharpness of 256x256 images.
        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
        upsample_temp = 0.997

        image = self.upscale_noise_scheduler.sample_noise(
            (batch_size, 3, 256, 256), device=torch_device, generator=generator
        ) * upsample_temp

        num_timesteps = len(self.upscale_noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            # i) define coefficients for time step t
            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )
            clipped_coeff = (
                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * self.upscale_noise_scheduler.get_beta(t)
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )
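            # These are the DDPM posterior coefficients: `pred_mean` below
            # recovers the x_0 estimate from the predicted noise, and x_{t-1}
            # mixes the clipped x_0 estimate with the current x_t.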
            # ii) predict noise residual
            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
            model_output = self.upscale_unet(image, time_input, low_res)
            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clipped_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = self.upscale_noise_scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        image = image[0].permute(1, 2, 0)
        return image
import torch
import PIL.Image

from diffusers import DiffusionPipeline

generator = torch.Generator()
generator = generator.manual_seed(0)

model_id = "fusing/glide-base"

# load model and scheduler
pipeline = DiffusionPipeline.from_pretrained(model_id)

# run inference (text-conditioned denoising + upscaling)
img = pipeline("a clip art of a hugging face", generator)

# process image to PIL
img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(img)

# save image
image_pil.save("test.png")
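
# Note: the pipeline returns an HWC tensor in roughly [-1, 1] (GLIDE.__call__
# permutes the channels last before returning), which is why the uint8
# conversion and PIL.Image.fromarray above work without a transpose.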
@@ -7,9 +7,10 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin
from .models.clip_text_transformer import CLIPTextModel
from .models.unet import UNetModel
from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .models.unet_ldm import UNetLDMModel
from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.glide_ddim import GlideDDIMScheduler

@@ -18,6 +18,6 @@
from .clip_text_transformer import CLIPTextModel
from .unet import UNetModel
from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .unet_ldm import UNetLDMModel
from .vqvae import VQModel
@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
        return a.reshape(bs, -1, length)

class GLIDEUNetModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding.
@@ -419,11 +419,11 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
        in_channels=3,
        model_channels=192,
        out_channels=6,
        num_res_blocks=3,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
@@ -435,28 +435,9 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        transformer_dim=None,
    ):
        super().__init__()
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
@@ -482,8 +463,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
            linear(time_embed_dim, time_embed_dim),
        )

        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
        self._feature_size = ch
@@ -635,7 +614,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
@@ -644,6 +623,91 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)
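    # Each input-block activation is pushed onto `hs` and popped in reverse by
    # the output blocks, forming the UNet's skip connections.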

class GLIDETextToImageUNetModel(GLIDEUNetModel):
    """
    A UNetModel used in GLIDE for text-guided image generation.

    Expects an extra kwarg `transformer_out` with the text encoder's last hidden
    states to condition generation on the prompt.
    """
    def __init__(
        self,
        in_channels=3,
        model_channels=192,
        out_channels=6,
        num_res_blocks=3,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        transformer_dim=512,
    ):
        super().__init__(
            in_channels=in_channels,
            model_channels=model_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            channel_mult=channel_mult,
            conv_resample=conv_resample,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            transformer_dim=transformer_dim,
        )
        self.register(
            in_channels=in_channels,
            model_channels=model_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            channel_mult=channel_mult,
            conv_resample=conv_resample,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            transformer_dim=transformer_dim,
        )
        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)

    def forward(self, x, timesteps, transformer_out=None):
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
@@ -663,3 +727,86 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
            h = torch.cat([h, other], dim=1)
            h = module(h, emb, transformer_out)
        return self.out(h)

class GLIDESuperResUNetModel(GLIDEUNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """
    def __init__(
        self,
        in_channels=3,
        model_channels=192,
        out_channels=6,
        num_res_blocks=3,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
    ):
        super().__init__(
            in_channels=in_channels,
            model_channels=model_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            channel_mult=channel_mult,
            conv_resample=conv_resample,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
        )
        self.register(
            in_channels=in_channels,
            model_channels=model_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            channel_mult=channel_mult,
            conv_resample=conv_resample,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
        )

    def forward(self, x, timesteps, low_res=None):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = torch.cat([x, upsampled], dim=1)
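        # The conditioning image is bilinearly upsampled to the target
        # resolution and stacked along the channel axis, which is why the
        # conversion script above builds this model with in_channels=6
        # (3 noisy RGB channels + 3 conditioning channels).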
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        h = x
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        return self.out(h)
@@ -39,9 +39,10 @@ LOADABLE_CLASSES = {
        "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
        "GaussianDDPMScheduler": ["save_config", "from_config"],
        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
        "GlideDDIMScheduler": ["save_config", "from_config"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
    },
}
@@ -18,3 +18,4 @@
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
from .glide_ddim import GlideDDIMScheduler
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np

from torch import nn

from ..configuration_utils import ConfigMixin
@@ -22,36 +22,30 @@ from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
SAMPLING_CONFIG_NAME = "scheduler_config.json"

class GlideDDIMScheduler(nn.Module, ConfigMixin):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_schedule="linear",
        variance_type="fixed_large",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_schedule=beta_schedule,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
            # Linear schedule from Ho et al, extended to work for any number of
            # diffusion steps.
            scale = 1000 / self.num_timesteps
            beta_start = scale * 0.0001
            beta_end = scale * 0.02
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
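            # With the default timesteps=1000, scale == 1 and this reduces to
            # the original DDPM range (beta_start=0.0001, beta_end=0.02); the
            # rescaling keeps the schedule comparable when fewer steps are used.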
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
betas = betas_for_alpha_bar(
timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
@@ -99,4 +93,4 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps