Commit dc6324d4 authored by anton-l

end-to-end glide pipeline with DDIM scheduler for upscaling

parent ff89f808
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -76,7 +76,7 @@ text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="
### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
state_dict = torch.load("upsample.pt", map_location="cpu")
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GLIDESuperResUNetModel(
in_channels=6,
@@ -93,12 +93,12 @@ superres_model = GLIDESuperResUNetModel(
resblock_updown=True,
)
superres_model.load_state_dict(state_dict)
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
upscale_unet=superres_model, upscale_noise_scheduler=scheduler)
upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)
glide.save_pretrained("./glide-base")
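As a quick round-trip check, the saved pipeline can be reloaded and its components inspected; a minimal sketch (the asserts are illustrative, not part of the conversion script):

# reload the converted pipeline from disk and confirm the scheduler wiring
glide = GLIDE.from_pretrained("./glide-base")
assert isinstance(glide.upscale_noise_scheduler, GlideDDIMScheduler)
assert isinstance(glide.text_noise_scheduler, ClassifierFreeGuidanceScheduler)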
@@ -18,7 +18,7 @@ import numpy as np
import torch
import tqdm
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from transformers import GPT2Tokenizer
@@ -41,17 +41,20 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class GLIDE(DiffusionPipeline):
def __init__(
self,
unet: GLIDETextToImageUNetModel,
noise_scheduler: ClassifierFreeGuidanceScheduler,
text_unet: GLIDETextToImageUNetModel,
text_noise_scheduler: ClassifierFreeGuidanceScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel,
upscale_noise_scheduler: GlideDDIMScheduler
):
super().__init__()
self.register_modules(
unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
)
def q_posterior_mean_variance(self, x_start, x_t, t):
def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
@@ -60,12 +63,12 @@ class GLIDE(DiffusionPipeline):
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
_extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
scheduler.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
@@ -75,7 +78,7 @@ class GLIDE(DiffusionPipeline):
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
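For reference, the coefficients used here are the standard DDPM posterior from Ho et al. (2020), assuming the scheduler attributes follow the usual definitions:

q(x_{t-1} | x_t, x_0) = N(x_{t-1}; μ̃_t(x_t, x_0), β̃_t I)
μ̃_t = (√ᾱ_{t-1} β_t / (1 − ᾱ_t)) · x_0 + (√α_t (1 − ᾱ_{t-1}) / (1 − ᾱ_t)) · x_t
β̃_t = β_t (1 − ᾱ_{t-1}) / (1 − ᾱ_t)

so posterior_mean_coef1 and posterior_mean_coef2 are the x_0 and x_t coefficients, and posterior_variance is β̃_t.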
def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
@@ -93,51 +96,60 @@ class GLIDE(DiffusionPipeline):
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, t, transformer_out)
if transformer_out is None:
# super-res model
model_output = model(x, t, low_res)
else:
# text2image model
model_output = model(x, t, transformer_out)
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = torch.split(model_output, C, dim=1)
min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
# The model_var_values are in [-1, 1], interpolating between min_var and max_var.
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = torch.exp(model_log_variance)
pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
if clip_denoised:
pred_xstart = pred_xstart.clamp(-1, 1)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
return model_mean, model_variance, model_log_variance, pred_xstart
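The interpolation above is the learned-range variance of Nichol & Dhariwal (2021): the extra output channels give v ∈ [−1, 1] per pixel, and

Σ_θ(x_t, t) = exp(f · log β_t + (1 − f) · log β̃_t),  f = (v + 1) / 2

so the model picks a variance between the two fixed extremes β̃_t (lower bound) and β_t (upper bound).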
def _predict_xstart_from_eps(self, x_t, t, eps):
def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
_extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
return (
_extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
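Both helpers are the ε-parametrization identities implied by the forward process x_t = √ᾱ_t x_0 + √(1 − ᾱ_t) ε, with sqrt_recip_alphas_cumprod = 1/√ᾱ_t and sqrt_recipm1_alphas_cumprod = √(1/ᾱ_t − 1):

x̂_0 = (1/√ᾱ_t) x_t − √(1/ᾱ_t − 1) ε
ε = ((1/√ᾱ_t) x_t − x̂_0) / √(1/ᾱ_t − 1)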
@torch.no_grad()
def __call__(self, prompt, generator=None, torch_device=None):
if torch_device is None:
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device)
self.text_unet.to(torch_device)
self.text_encoder.to(torch_device)
self.upscale_unet.to(torch_device)
# Create a classifier-free guidance sampling function
guidance_scale = 3.0
def model_fn(x_t, ts, transformer_out, **kwargs):
def text_model_fn(x_t, ts, transformer_out, **kwargs):
half = x_t[: len(x_t) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.unet(combined, ts, transformer_out, **kwargs)
model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
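This implements classifier-free guidance (Ho & Salimans): the batch carries the conditioned and unconditioned copy of the same latent, and the two ε predictions are recombined as

ε̂ = ε_uncond + s · (ε_cond − ε_uncond),  with s = guidance_scale = 3.0

where s > 1 pushes samples toward the text condition at some cost in diversity.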
@@ -146,8 +158,8 @@ class GLIDE(DiffusionPipeline):
# 1. Sample gaussian noise
batch_size = 2 # second image is empty for classifier-free guidance
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
image = self.text_noise_scheduler.sample_noise(
(batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
)
# 2. Encode tokens
@@ -157,14 +169,60 @@
attention_mask = inputs["attention_mask"].to(torch_device)
transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
num_timesteps = len(self.noise_scheduler)
# 3. Run the text2image generation step
num_timesteps = len(self.text_noise_scheduler)
for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
t = torch.tensor([i] * image.shape[0], device=torch_device)
mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out)
noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
mean, variance, log_variance, pred_xstart = self.p_mean_variance(
text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
)
noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
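This is the standard ancestral sampling step, with the noise term suppressed at the last step:

x_{t-1} = μ_θ(x_t, t) + 1[t > 0] · exp(½ · log Σ_θ) · z,  z ~ N(0, I)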
# 4. Run the upscaling step
batch_size = 1
image = image[:1]
low_res = ((image + 1) * 127.5).round() / 127.5 - 1
eta = 0.0
# Tune this parameter to control the sharpness of 256x256 images.
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997
image = self.upscale_noise_scheduler.sample_noise(
(batch_size, 3, 256, 256), device=torch_device, generator=generator
) * upsample_temp
num_timesteps = len(self.upscale_noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (
    (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
    * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
    / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
    torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
    * self.upscale_noise_scheduler.get_beta(t)
    / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
# ii) predict noise residual
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
model_output = self.upscale_unet(image, time_input, low_res)
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
# iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = self.upscale_noise_scheduler.sample_variance(
    t, prev_image.shape, device=torch_device, generator=generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
image = sampled_prev_image
image = image[0].permute(1, 2, 0)
return image
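With eta = 0.0, steps i)–iii) amount to the clipped DDPM posterior mean rather than a distinct DDIM update (assuming get_alpha_prod, get_alpha and get_beta return ᾱ_t, α_t and β_t): pred_mean is the ε-parametrization estimate x̂_0 = (1/√ᾱ_t) x_t − √(1/ᾱ_t − 1) ε_θ, clamped to [−1, 1], and clipped_coeff / image_coeff equal posterior_mean_coef1 / posterior_mean_coef2 from q_posterior_mean_variance above, so prev_image = coef1 · x̂_0 + coef2 · x_t.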
@@ -9,7 +9,6 @@ matplotlib.rcParams['interactive'] = True
generator = torch.Generator()
generator = generator.manual_seed(0)
# 1. Load models
pipeline = GLIDE.from_pretrained("fusing/glide-base")
img = pipeline("a pencil sketch of a corgi", generator)
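A minimal way to view the result, assuming img is the HWC float tensor in [−1, 1] returned by the pipeline (matplotlib is already configured above):

import matplotlib.pyplot as plt

# map [-1, 1] floats to [0, 1] for imshow
plt.imshow((((img + 1) * 127.5).round().clamp(0, 255) / 255.0).cpu().numpy())
plt.show()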
@@ -13,3 +13,4 @@ from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.glide_ddim import GlideDDIMScheduler
@@ -419,11 +419,11 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
@@ -438,24 +438,6 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
transformer_dim=None,
):
super().__init__()
self.register(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
@@ -632,7 +614,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps, y=None):
def forward(self, x, timesteps):
"""
Apply the model to an input batch.
@@ -641,17 +623,10 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
@@ -671,10 +646,66 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
Expects an extra kwarg `transformer_out` with text-encoder hidden states to condition on.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_dim=512
):
super().__init__(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
)
self.register(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
)
self.transformer_proj = nn.Linear(kwargs["transformer_dim"], self.model_channels * 4)
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
def forward(self, x, timesteps, transformer_out=None):
hs = []
@@ -705,11 +736,77 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
):
super().__init__(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
)
self.register(
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
)
def forward(self, x, timesteps, low_res=None, **kwargs):
def forward(self, x, timesteps, low_res=None):
_, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
x = torch.cat([x, upsampled], dim=1)
return super().forward(x, timesteps, **kwargs)
\ No newline at end of file
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
h = x
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
return self.out(h)
\ No newline at end of file
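After the channel concatenation in forward, the super-resolution UNet consumes 6 input channels (3 noisy RGB plus 3 from the bilinearly upsampled low_res), which matches in_channels=6 in the conversion script; its 6 output channels are later split 3/3 into the noise residual and the variance values in p_mean_variance.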
@@ -39,6 +39,7 @@ LOADABLE_CLASSES = {
"CLIPTextModel": ["save_pretrained", "from_pretrained"], # TODO (Anton): move to transformers
"GaussianDDPMScheduler": ["save_config", "from_config"],
"ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
"GlideDDIMScheduler": ["save_config", "from_config"],
},
"transformers": {
"GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
@@ -18,3 +18,4 @@
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
from .glide_ddim import GlideDDIMScheduler
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
import numpy as np
from torch import nn
from ..configuration_utils import ConfigMixin
@@ -22,36 +22,30 @@ from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
SAMPLING_CONFIG_NAME = "scheduler_config.json"
class GaussianDDPMScheduler(nn.Module, ConfigMixin):
class GlideDDIMScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME
def __init__(
self,
timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
variance_type="fixed_small",
variance_type="fixed_large"
):
super().__init__()
self.register(
timesteps=timesteps,
beta_start=beta_start,
beta_end=beta_end,
beta_schedule=beta_schedule,
variance_type=variance_type,
)
self.num_timesteps = int(timesteps)
if beta_schedule == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / self.num_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
betas = betas_for_alpha_bar(
timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
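As a concrete check of the scaling: timesteps=1000 gives scale = 1 and the original Ho et al. range [1e-4, 0.02], while timesteps=50 gives scale = 20 and [2e-3, 0.4], keeping the cumulative amount of noise injected over the shorter trajectory roughly comparable.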
@@ -99,4 +93,4 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
return torch.randn(shape, generator=generator).to(device)
def __len__(self):
return self.num_timesteps
return self.num_timesteps
\ No newline at end of file