Commit f9cdb4dd authored by anton-l
Browse files

Convert glide upsampling weights

parent 43e728d3
import torch import torch
from torch import nn from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from modeling_glide import GLIDE from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer from transformers import CLIPTextConfig, GPT2Tokenizer
...@@ -51,9 +51,9 @@ for layer_idx in range(config.num_hidden_layers): ...@@ -51,9 +51,9 @@ for layer_idx in range(config.num_hidden_layers):
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"] hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"] hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
### Convert the UNet ### Convert the Text-to-Image UNet
unet_model = UNetGLIDEModel( text2im_model = GLIDETextToImageUNetModel(
in_channels=3, in_channels=3,
model_channels=192, model_channels=192,
out_channels=6, out_channels=6,
...@@ -69,10 +69,38 @@ unet_model = UNetGLIDEModel( ...@@ -69,10 +69,38 @@ unet_model = UNetGLIDEModel(
transformer_dim=512, transformer_dim=512,
) )
# Load the converted OpenAI GLIDE weights into the text-to-image UNet.
# strict=False: the checkpoint also carries the text transformer weights,
# which were copied into the CLIP text model above.
text2im_model.load_state_dict(state_dict, strict=False)

text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

### Convert the Super-Resolution UNet
# Checkpoint obtained with:
#   wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
state_dict = torch.load("upsample.pt", map_location="cpu")

superres_model = GLIDESuperResUNetModel(
    in_channels=6,  # noisy target + upsampled low-res image, concatenated channel-wise
    model_channels=192,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(8, 16, 32),
    dropout=0.1,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)
superres_model.load_state_dict(state_dict)

upscale_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

# FIX: the upscale scheduler was previously passed as `scheduler`, a name that
# no longer exists after the rename to `text_scheduler`/`upscale_scheduler`,
# which raised a NameError here.
glide = GLIDE(
    text_unet=text2im_model,
    text_noise_scheduler=text_scheduler,
    text_encoder=model,
    tokenizer=tokenizer,
    upscale_unet=superres_model,
    upscale_noise_scheduler=upscale_scheduler,
)

glide.save_pretrained("./glide-base")
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
import torch import torch
import tqdm import tqdm
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from transformers import GPT2Tokenizer from transformers import GPT2Tokenizer
...@@ -41,7 +41,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): ...@@ -41,7 +41,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class GLIDE(DiffusionPipeline): class GLIDE(DiffusionPipeline):
def __init__( def __init__(
self, self,
unet: UNetGLIDEModel, unet: GLIDETextToImageUNetModel,
noise_scheduler: ClassifierFreeGuidanceScheduler, noise_scheduler: ClassifierFreeGuidanceScheduler,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer, tokenizer: GPT2Tokenizer,
......
...@@ -12,7 +12,7 @@ generator = generator.manual_seed(0) ...@@ -12,7 +12,7 @@ generator = generator.manual_seed(0)
# 1. Load models # 1. Load models
pipeline = GLIDE.from_pretrained("fusing/glide-base") pipeline = GLIDE.from_pretrained("fusing/glide-base")
img = pipeline("an oil painting of a corgi", generator) img = pipeline("a pencil sketch of a corgi", generator)
img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy() img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
plt.figure(figsize=(8, 8)) plt.figure(figsize=(8, 8))
......
...@@ -7,7 +7,7 @@ __version__ = "0.0.1" ...@@ -7,7 +7,7 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models.clip_text_transformer import CLIPTextModel from .models.clip_text_transformer import CLIPTextModel
from .models.unet import UNetModel from .models.unet import UNetModel
from .models.unet_glide import UNetGLIDEModel from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .models.unet_ldm import UNetLDMModel from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
......
...@@ -18,5 +18,5 @@ ...@@ -18,5 +18,5 @@
from .clip_text_transformer import CLIPTextModel from .clip_text_transformer import CLIPTextModel
from .unet import UNetModel from .unet import UNetModel
from .unet_glide import UNetGLIDEModel from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .unet_ldm import UNetLDMModel from .unet_ldm import UNetLDMModel
...@@ -388,7 +388,7 @@ class QKVAttention(nn.Module): ...@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
class UNetGLIDEModel(ModelMixin, ConfigMixin): class GLIDEUNetModel(ModelMixin, ConfigMixin):
""" """
The full UNet model with attention and timestep embedding. The full UNet model with attention and timestep embedding.
...@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): ...@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample=-1, num_heads_upsample=-1,
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
transformer_dim=512, transformer_dim=None,
): ):
super().__init__() super().__init__()
self.register( self.register(
...@@ -455,7 +455,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): ...@@ -455,7 +455,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample=num_heads_upsample, num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
transformer_dim=transformer_dim,
) )
if num_heads_upsample == -1: if num_heads_upsample == -1:
...@@ -482,8 +481,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): ...@@ -482,8 +481,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
linear(time_embed_dim, time_embed_dim), linear(time_embed_dim, time_embed_dim),
) )
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
ch = input_ch = int(channel_mult[0] * model_channels) ch = input_ch = int(channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]) self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
self._feature_size = ch self._feature_size = ch
...@@ -635,7 +632,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): ...@@ -635,7 +632,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
self.middle_block.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps, y=None):
    """
    Apply the model to an input batch.

    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param y: an [N] Tensor of labels, if class-conditional.
    :return: an [N x C x ...] Tensor of outputs.
    """
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"

    # Timestep conditioning, plus an optional class-label embedding.
    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
    if self.num_classes is not None:
        assert y.shape == (x.shape[0],)
        emb = emb + self.label_emb(y)

    # Encoder: run the input blocks, stashing each activation for the skips.
    skips = []
    h = x.type(self.dtype)
    for block in self.input_blocks:
        h = block(h, emb)
        skips.append(h)

    h = self.middle_block(h, emb)

    # Decoder: concatenate the matching skip activation channel-wise
    # before each output block.
    for block in self.output_blocks:
        h = torch.cat([h, skips.pop()], dim=1)
        h = block(h, emb)

    return self.out(h.type(x.dtype))
class GLIDETextToImageUNetModel(GLIDEUNetModel):
    """
    A UNetModel that generates images conditioned on a text encoding.

    Expects an extra forward kwarg `transformer_out` carrying the text
    transformer's token features.
    """

    # FIX: the docstring was copy-pasted from the super-resolution model and
    # wrongly described this class as performing super-resolution. Also,
    # `kwargs["transformer_dim"]` raised a KeyError whenever the argument was
    # omitted or passed positionally; it is now an explicit keyword with the
    # previous default of 512.
    def __init__(self, *args, transformer_dim=512, **kwargs):
        # Forward transformer_dim to the base class, which accepts it in its
        # signature alongside the other UNet hyperparameters.
        super().__init__(*args, transformer_dim=transformer_dim, **kwargs)
        # Project transformer token features to the time-embedding width
        # (model_channels * 4) so they can condition the UNet blocks.
        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
def forward(self, x, timesteps, transformer_out=None):
hs = [] hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
...@@ -663,3 +696,20 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin): ...@@ -663,3 +696,20 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
h = torch.cat([h, other], dim=1) h = torch.cat([h, other], dim=1)
h = module(h, emb, transformer_out) h = module(h, emb, transformer_out)
return self.out(h) return self.out(h)
class GLIDESuperResUNetModel(GLIDEUNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra forward kwarg `low_res` to condition on a
    low-resolution image.
    """

    # NOTE: the previous pass-through `__init__(*args, **kwargs)` only called
    # super() with unchanged arguments and has been removed as redundant.

    def forward(self, x, timesteps, low_res=None, **kwargs):
        # FIX: `low_res` defaulted to None but was passed straight into
        # F.interpolate, crashing with an opaque TypeError when omitted;
        # fail fast with a clear message instead.
        if low_res is None:
            raise ValueError("GLIDESuperResUNetModel requires the `low_res` conditioning image")
        # Upsample the conditioning image to the target spatial size and
        # concatenate it with the noisy input along the channel axis (the
        # converter configures in_channels=6 to account for this).
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = torch.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment