Commit f9cdb4dd authored by anton-l's avatar anton-l
Browse files

Convert glide upsampling weights

parent 43e728d3
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer
......@@ -51,9 +51,9 @@ for layer_idx in range(config.num_hidden_layers):
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
### Convert the UNet
### Convert the Text-to-Image UNet
unet_model = UNetGLIDEModel(
text2im_model = GLIDETextToImageUNetModel(
in_channels=3,
model_channels=192,
out_channels=6,
......@@ -69,10 +69,38 @@ unet_model = UNetGLIDEModel(
transformer_dim=512,
)
unet_model.load_state_dict(state_dict, strict=False)
text2im_model.load_state_dict(state_dict, strict=False)
scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
### Convert the Super-Resolution UNet

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load("upsample.pt", map_location="cpu")

superres_model = GLIDESuperResUNetModel(
    # GLIDESuperResUNetModel.forward concatenates the upsampled low-res image to `x`
    # along the channel axis, so the network input has more channels than `x` alone.
    in_channels=6,
    model_channels=192,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(8, 16, 32),
    dropout=0.1,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)

# Strict loading: every key in the upsampler checkpoint must match the model.
superres_model.load_state_dict(state_dict)

upscale_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

# Pass the scheduler created for the upsampler. The previous code referenced the
# stale name `scheduler` and left `upscale_scheduler` unused.
glide = GLIDE(
    text_unet=text2im_model,
    text_noise_scheduler=text_scheduler,
    text_encoder=model,
    tokenizer=tokenizer,
    upscale_unet=superres_model,
    upscale_noise_scheduler=upscale_scheduler,
)

glide.save_pretrained("./glide-base")
......@@ -18,7 +18,7 @@ import numpy as np
import torch
import tqdm
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel
from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from transformers import GPT2Tokenizer
......@@ -41,7 +41,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class GLIDE(DiffusionPipeline):
def __init__(
self,
unet: UNetGLIDEModel,
unet: GLIDETextToImageUNetModel,
noise_scheduler: ClassifierFreeGuidanceScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
......
......@@ -12,7 +12,7 @@ generator = generator.manual_seed(0)
# 1. Load models
pipeline = GLIDE.from_pretrained("fusing/glide-base")
img = pipeline("an oil painting of a corgi", generator)
img = pipeline("a pencil sketch of a corgi", generator)
img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
plt.figure(figsize=(8, 8))
......
......@@ -7,7 +7,7 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin
from .models.clip_text_transformer import CLIPTextModel
from .models.unet import UNetModel
from .models.unet_glide import UNetGLIDEModel
from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
......
......@@ -18,5 +18,5 @@
from .clip_text_transformer import CLIPTextModel
from .unet import UNetModel
from .unet_glide import UNetGLIDEModel
from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .unet_ldm import UNetLDMModel
......@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
return a.reshape(bs, -1, length)
class UNetGLIDEModel(ModelMixin, ConfigMixin):
class GLIDEUNetModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding.
......@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_dim=512,
transformer_dim=None,
):
super().__init__()
self.register(
......@@ -455,7 +455,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim,
)
if num_heads_upsample == -1:
......@@ -482,8 +481,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
linear(time_embed_dim, time_embed_dim),
)
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
ch = input_ch = int(channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
self._feature_size = ch
......@@ -635,7 +632,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps, transformer_out):
def forward(self, x, timesteps, y=None):
"""
Apply the model to an input batch.
......@@ -644,6 +641,42 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)
h = h.type(x.dtype)
return self.out(h)
class GLIDETextToImageUNetModel(GLIDEUNetModel):
"""
A UNetModel that is conditioned on text.
Expects an extra kwarg `transformer_out` with the text transformer's output to condition on.
"""
def __init__(self, *args, transformer_dim=512, **kwargs):
    """
    Build the base UNet, then add the `transformer_proj` linear layer
    (transformer_dim -> 4 * model_channels).

    :param transformer_dim: input feature size of `transformer_proj`. Must be given by keyword,
                            as before; when omitted it now falls back to 512 instead of
                            raising a KeyError.
    """
    # Forward the value by keyword so the base class receives it exactly as it did
    # when it arrived inside **kwargs.
    super().__init__(*args, transformer_dim=transformer_dim, **kwargs)
    self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
def forward(self, x, timesteps, transformer_out=None):
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
......@@ -663,3 +696,20 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
h = torch.cat([h, other], dim=1)
h = module(h, emb, transformer_out)
return self.out(h)
class GLIDESuperResUNetModel(GLIDEUNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """
        Apply the model to `x`, conditioned on a low-resolution image.

        :param x: a 4-D input Tensor; its last two dimensions set the size `low_res` is resized to.
        :param timesteps: passed through unchanged to the base model's forward.
        :param low_res: the low-resolution image Tensor to condition on. Required, even though
                        the signature defaults it to None.
        :param kwargs: extra keyword arguments passed through to the base model's forward.
        :return: the output of the base model's forward.
        :raises ValueError: if `low_res` is not provided.
        """
        if low_res is None:
            # Fail with a clear message rather than an opaque error inside F.interpolate.
            raise ValueError("`low_res` must be provided to GLIDESuperResUNetModel.forward")

        _, _, new_height, new_width = x.shape
        # Resize the conditioning image to the spatial size of `x` so both can be stacked.
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        # Channel-wise concatenation: the base UNet sees `x` and the conditioning image together.
        x = torch.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment