Unverified Commit e7fe901e authored by Patrick von Platen, committed by GitHub

save intermediate (#87)

* save intermediate

* up

* up
parent c3d78cd3
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest

import numpy as np

import torch

from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    GlidePipeline,
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
    LatentDiffusionPipeline,
    LatentDiffusionUncondPipeline,
    NCSNpp,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    ScoreSdeVpPipeline,
    ScoreSdeVpScheduler,
    UNetLDMModel,
    UNetModel,
    UNetUnconditionalModel,
    VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel
def test_output_pretrained_ldm_dummy():
    model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()


def test_output_pretrained_ldm():
    model = UNetUnconditionalModel.from_pretrained("fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)
    import ipdb; ipdb.set_trace()
# To see how the final model should look
test_output_pretrained_ldm_dummy()
test_output_pretrained_ldm()

# => this is the architecture in which the model should be saved in the new format
# -> verify the new repo with the following tests (in `test_modeling_utils.py`)
# - test_ldm_uncond (in PipelineTesterMixin)
# - test_output_pretrained (in UNetLDMModelTests)
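As a sanity check, the converted weights should round-trip through the new format. A minimal sketch of that round trip (the local target directory is illustrative only, and it assumes the usual `save_pretrained`/`from_pretrained` behavior from `ModelMixin`):

```python
from diffusers import UNetUnconditionalModel

# load via the old-format shim, then write out in the new format
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
model.save_pretrained("./unet-ldm-dummy-new-format")  # writes config + weights

# after conversion, the checkpoint should load without the `ldm=True` shim
model_new = UNetUnconditionalModel.from_pretrained("./unet-ldm-dummy-new-format")
```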
@@ -111,7 +111,7 @@ prompt = "A painting of a squirrel eating a burger"
 image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)
 
 image_processed = image.cpu().permute(0, 2, 3, 1)
-image_processed = image_processed * 255.
+image_processed = image_processed * 255.0
 image_processed = image_processed.numpy().astype(np.uint8)
 image_pil = PIL.Image.fromarray(image_processed[0])
@@ -143,6 +143,7 @@ audio = bddm(mel_spec, generator, torch_device=torch_device)
 # save generated audio
 from scipy.io.wavfile import write as wavwrite
 sampling_rate = 22050
 wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
 ```
......
@@ -116,6 +116,7 @@ class ConfigMixin:
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
 
         user_agent = {"file_type": "config"}
@@ -150,6 +151,7 @@ class ConfigMixin:
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
+                subfolder=subfolder,
             )
         except RepositoryNotFoundError:
......
@@ -321,6 +321,7 @@ class ModelMixin(torch.nn.Module):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        subfolder = kwargs.pop("subfolder", None)
 
         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -336,6 +337,7 @@ class ModelMixin(torch.nn.Module):
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 revision=revision,
+                subfolder=subfolder,
                 **kwargs,
             )
             model.register_to_config(name_or_path=pretrained_model_name_or_path)
@@ -363,6 +365,7 @@ class ModelMixin(torch.nn.Module):
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
+                subfolder=subfolder,
             )
         except RepositoryNotFoundError:
......
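Both `ConfigMixin` and `ModelMixin` now thread a `subfolder` kwarg through to the Hub download, so a single repo can host several components (unet, vae, ...) each in its own directory. A minimal sketch of the intended usage, mirroring the call in the conversion script above:

```python
from diffusers import UNetUnconditionalModel

# pull just the `unet/` component out of a multi-model repo
unet = UNetUnconditionalModel.from_pretrained(
    "fusing/latent-diffusion-celeba-256", subfolder="unet", ldm=True
)
```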
@@ -51,6 +51,7 @@ class AttentionBlock(nn.Module):
         overwrite_qkv=False,
         overwrite_linear=False,
         rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
         self.channels = channels
@@ -62,7 +63,7 @@ class AttentionBlock(nn.Module):
             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
             self.num_heads = channels // num_head_channels
 
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = self.num_heads
         self.rescale_output_factor = rescale_output_factor
@@ -165,7 +166,7 @@ class AttentionBlock(nn.Module):
         return result
 
 
-class AttentionBlockNew(nn.Module):
+class AttentionBlockNew_2(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
@@ -180,11 +181,14 @@ class AttentionBlockNew(nn.Module):
         num_groups=32,
         encoder_channels=None,
         rescale_output_factor=1.0,
+        eps=1e-5,
     ):
         super().__init__()
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.channels = channels
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = channels // num_head_channels
+        self.num_head_size = num_head_channels
         self.rescale_output_factor = rescale_output_factor
 
         if encoder_channels is not None:
@@ -192,6 +196,28 @@
         self.proj = zero_module(nn.Conv1d(channels, channels, 1))
 
+        # ------------------------- new -----------------------
+        num_heads = self.n_heads
+        self.channels = channels
+
+        if num_head_channels is None:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+
+        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+        # define q,k,v as linear layers
+        self.query = nn.Linear(channels, channels)
+        self.key = nn.Linear(channels, channels)
+        self.value = nn.Linear(channels, channels)
+
+        self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        # ------------------------- new -----------------------
+
     def set_weight(self, attn_layer):
         self.norm.weight.data = attn_layer.norm.weight.data
         self.norm.bias.data = attn_layer.norm.bias.data
@@ -202,6 +228,89 @@
         self.proj.weight.data = attn_layer.proj.weight.data
         self.proj.bias.data = attn_layer.proj.bias.data
 
+        if hasattr(attn_layer, "q"):
+            module = attn_layer
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
+                :, :, :, 0
+            ]
+            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
+
+            self.qkv.weight.data = qkv_weight
+            self.qkv.bias.data = qkv_bias
+
+            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
+            proj_out.bias.data = module.proj_out.bias.data
+
+            self.proj = proj_out
+
+        self.set_weights_2(attn_layer)
+
+    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+        new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
+        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+        return new_projection
+
+    def set_weights_2(self, attn_layer):
+        self.group_norm.weight.data = attn_layer.norm.weight.data
+        self.group_norm.bias.data = attn_layer.norm.bias.data
+
+        qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels)
+        qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)
+
+        q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
+        q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)
+
+        self.query.weight.data = q_w.reshape(-1, self.channels)
+        self.key.weight.data = k_w.reshape(-1, self.channels)
+        self.value.weight.data = v_w.reshape(-1, self.channels)
+
+        self.query.bias.data = q_b.reshape(-1)
+        self.key.bias.data = k_b.reshape(-1)
+        self.value.bias.data = v_b.reshape(-1)
+
+        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+        self.proj_attn.bias.data = attn_layer.proj.bias.data
+
+    def forward_2(self, hidden_states):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        # transpose
+        query_states = self.transpose_for_scores(query_proj)
+        key_states = self.transpose_for_scores(key_proj)
+        value_states = self.transpose_for_scores(value_proj)
+
+        # get scores
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # compute attention output
+        context_states = torch.matmul(attention_probs, value_states)
+
+        context_states = context_states.permute(0, 2, 1, 3).contiguous()
+        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
+        context_states = context_states.view(new_context_states_shape)
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(context_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
     def forward(self, x, encoder_out=None):
         b, c, *spatial = x.shape
         hid_states = self.norm(x).view(b, c, -1)
@@ -230,10 +339,119 @@
         h = h.reshape(b, c, *spatial)
 
         result = x + h
 
         result = result / self.rescale_output_factor
-        return result
+
+        result_2 = self.forward_2(x)
+        print((result - result_2).abs().sum())
+
+        return result_2
+
+
+class AttentionBlockNew(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+    to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    Uses three q, k, v linear layers to compute attention
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=None,
+        num_groups=32,
+        rescale_output_factor=1.0,
+        eps=1e-5,
+    ):
+        super().__init__()
+        self.channels = channels
+
+        if num_head_channels is None:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.num_head_size = num_head_channels
+
+        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+        # define q,k,v as linear layers
+        self.query = nn.Linear(channels, channels)
+        self.key = nn.Linear(channels, channels)
+        self.value = nn.Linear(channels, channels)
+
+        self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+
+    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size)
+        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+        return new_projection
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+        batch, channel, height, width = hidden_states.shape
+
+        # norm
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+        # proj to q, k, v
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        # transpose
+        query_states = self.transpose_for_scores(query_proj)
+        key_states = self.transpose_for_scores(key_proj)
+        value_states = self.transpose_for_scores(value_proj)
+
+        # get scores
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.channels // self.num_heads)
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # compute attention output
+        context_states = torch.matmul(attention_probs, value_states)
+
+        context_states = context_states.permute(0, 2, 1, 3).contiguous()
+        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
+        context_states = context_states.view(new_context_states_shape)
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(context_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+        # res connect and rescale
+        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        return hidden_states
+
+    def set_weight(self, attn_layer):
+        self.group_norm.weight.data = attn_layer.norm.weight.data
+        self.group_norm.bias.data = attn_layer.norm.bias.data
+
+        qkv_weight = attn_layer.qkv.weight.data.reshape(
+            self.num_heads, 3 * self.channels // self.num_heads, self.channels
+        )
+        qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
+
+        q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
+        q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
+
+        self.query.weight.data = q_w.reshape(-1, self.channels)
+        self.key.weight.data = k_w.reshape(-1, self.channels)
+        self.value.weight.data = v_w.reshape(-1, self.channels)
+
+        self.query.bias.data = q_b.reshape(-1)
+        self.key.bias.data = k_b.reshape(-1)
+        self.value.bias.data = v_b.reshape(-1)
+
+        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
+        self.proj_attn.bias.data = attn_layer.proj.bias.data
+
+
 class SpatialTransformer(nn.Module):
......
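The rewritten `AttentionBlockNew` swaps the fused `nn.Conv1d(channels, channels * 3, 1)` qkv projection for three `nn.Linear` layers. `set_weight` relies on two facts: a kernel-size-1 Conv1d over the flattened spatial axis is a per-position Linear, and the fused weight rows are laid out head by head as `[q | k | v]` blocks. A standalone check of that assumption (a sketch, not part of the diff):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
channels, heads, tokens = 8, 2, 5
d = channels // heads

# old block: one fused 1x1 conv producing q, k, v interleaved per head
qkv = nn.Conv1d(channels, channels * 3, 1)

# rebuild the query Linear exactly the way set_weight() does
w = qkv.weight.data.reshape(heads, 3 * d, channels)
b = qkv.bias.data.reshape(heads, 3 * d)
q_w, _, _ = w.split(d, dim=1)
q_b, _, _ = b.split(d, dim=1)

query = nn.Linear(channels, channels)
query.weight.data = q_w.reshape(-1, channels)
query.bias.data = q_b.reshape(-1)

x = torch.randn(1, channels, tokens)  # (B, C, T)

# q channels of the fused conv output, regrouped head by head
q_conv = qkv(x).reshape(1, heads, 3 * d, tokens)[:, :, :d, :].reshape(1, channels, tokens)
# the same projection via the Linear layer (acts on the last dim)
q_lin = query(x.transpose(1, 2)).transpose(1, 2)

print(torch.allclose(q_conv, q_lin, atol=1e-6))  # True
```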
@@ -81,8 +81,10 @@ class Downsample2D(nn.Module):
             self.conv = conv
         elif name == "Conv2d_0":
             self.Conv2d_0 = conv
+            self.conv = conv
         else:
             self.op = conv
+            self.conv = conv
 
     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -90,13 +92,16 @@
             pad = (0, 1, 0, 1)
             x = F.pad(x, pad, mode="constant", value=0)
 
+        return self.conv(x)
+
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
-        if self.name == "conv":
-            return self.conv(x)
-        elif self.name == "Conv2d_0":
-            return self.Conv2d_0(x)
-        else:
-            return self.op(x)
+        # if self.name == "conv":
+        #     return self.conv(x)
+        # elif self.name == "Conv2d_0":
+        #     return self.Conv2d_0(x)
+        # else:
+        #     return self.op(x)
 
 
 class Upsample1D(nn.Module):
@@ -656,9 +661,9 @@ class ResnetBlock(nn.Module):
         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if time_embedding_norm == "default" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
         elif time_embedding_norm == "scale_shift" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+            self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
 
         self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         self.dropout = torch.nn.Dropout(dropout)
@@ -691,9 +696,9 @@ class ResnetBlock(nn.Module):
         self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
 
-        self.nin_shortcut = None
+        self.conv_shortcut = None
         if self.use_nin_shortcut:
-            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
     def forward(self, x, temb):
         h = x
@@ -715,7 +720,7 @@ class ResnetBlock(nn.Module):
         h = self.nonlinearity(h)
 
         if temb is not None:
-            temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
         else:
             temb = 0
@@ -738,8 +743,8 @@ class ResnetBlock(nn.Module):
         h = self.norm2(h)
         h = self.nonlinearity(h)
 
-        if self.nin_shortcut is not None:
-            x = self.nin_shortcut(x)
+        if self.conv_shortcut is not None:
+            x = self.conv_shortcut(x)
 
         return (x + h) / self.output_scale_factor
@@ -750,8 +755,8 @@ class ResnetBlock(nn.Module):
         self.conv1.weight.data = resnet.conv1.weight.data
         self.conv1.bias.data = resnet.conv1.bias.data
 
-        self.temb_proj.weight.data = resnet.temb_proj.weight.data
-        self.temb_proj.bias.data = resnet.temb_proj.bias.data
+        self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
+        self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
 
         self.norm2.weight.data = resnet.norm2.weight.data
         self.norm2.bias.data = resnet.norm2.bias.data
@@ -760,8 +765,8 @@ class ResnetBlock(nn.Module):
         self.conv2.bias.data = resnet.conv2.bias.data
 
         if self.use_nin_shortcut:
-            self.nin_shortcut.weight.data = resnet.nin_shortcut.weight.data
-            self.nin_shortcut.bias.data = resnet.nin_shortcut.bias.data
+            self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
+            self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
 
 # TODO(Patrick) - just there to convert the weights; can delete afterward
......
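The `temb_proj` → `time_emb_proj` and `nin_shortcut` → `conv_shortcut` renames mean pre-rename checkpoints no longer match the module tree by name. A hedged sketch of a one-off state-dict migration (file paths are placeholders; only the substring pairs come from the diff above):

```python
import torch

# old parameter-name substrings mapped to their renamed counterparts
RENAMES = {"temb_proj": "time_emb_proj", "nin_shortcut": "conv_shortcut"}

def migrate_state_dict(state_dict):
    migrated = {}
    for key, tensor in state_dict.items():
        for old, new in RENAMES.items():
            key = key.replace(old, new)
        migrated[key] = tensor
    return migrated

old_sd = torch.load("old_checkpoint.bin", map_location="cpu")  # placeholder path
torch.save(migrate_state_dict(old_sd), "migrated_checkpoint.bin")
```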
@@ -177,7 +177,9 @@ class UNetModel(ModelMixin, ConfigMixin):
                 hs.append(self.down[i_level].downsample(hs[-1]))
 
         # middle
+        print("hs", hs[-1].abs().sum())
         h = self.mid_new(hs[-1], temb)
+        print("h", h.abs().sum())
 
         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
......
@@ -29,9 +29,10 @@ def get_down_block(
     resnet_eps,
     resnet_act_fn,
     attn_num_head_channels,
+    downsample_padding=None,
 ):
     if down_block_type == "UNetResDownBlock2D":
-        return UNetResAttnDownBlock2D(
+        return UNetResDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -39,6 +40,7 @@ def get_down_block(
             add_downsample=add_downsample,
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
         )
     elif down_block_type == "UNetResAttnDownBlock2D":
         return UNetResAttnDownBlock2D(
@@ -57,7 +59,8 @@ def get_up_block(
     up_block_type,
     num_layers,
     in_channels,
-    next_channels,
+    out_channels,
+    prev_output_channel,
     temb_channels,
     add_upsample,
     resnet_eps,
@@ -68,7 +71,8 @@
         return UNetResUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
-            next_channels=next_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
             temb_channels=temb_channels,
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
@@ -78,7 +82,8 @@
         return UNetResAttnUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
-            next_channels=next_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
             add_upsample=add_upsample,
             resnet_eps=resnet_eps,
@@ -100,11 +105,14 @@ class UNetMidBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         attn_num_head_channels=1,
+        attention_type="default",
         output_scale_factor=1.0,
         **kwargs,
     ):
         super().__init__()
 
+        self.attention_type = attention_type
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock(
@@ -128,6 +136,7 @@
                 in_channels,
                 num_head_channels=attn_num_head_channels,
                 rescale_output_factor=output_scale_factor,
+                eps=resnet_eps,
             )
         )
         resnets.append(
@@ -148,18 +157,15 @@
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def forward(self, hidden_states, temb=None, encoder_states=None, mask=None):
-        if mask is not None:
-            hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
-        else:
-            hidden_states = self.resnets[0](hidden_states, temb)
+    def forward(self, hidden_states, temb=None, encoder_states=None):
+        hidden_states = self.resnets[0](hidden_states, temb)
 
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(hidden_states, encoder_states)
-            if mask is not None:
-                hidden_states = resnet(hidden_states, temb, mask=mask)
+            if self.attention_type == "default":
+                hidden_states = attn(hidden_states)
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, encoder_states)
+            hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
@@ -178,6 +184,7 @@ class UNetResAttnDownBlock2D(nn.Module):
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         attn_num_head_channels=1,
+        attention_type="default",
         output_scale_factor=1.0,
         add_downsample=True,
     ):
@@ -185,6 +192,8 @@
         resnets = []
         attentions = []
 
+        self.attention_type = attention_type
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -206,6 +215,7 @@
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
                 )
             )
@@ -251,6 +261,7 @@ class UNetResDownBlock2D(nn.Module):
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_downsample=True,
+        downsample_padding=1,
     ):
         super().__init__()
         resnets = []
@@ -276,7 +287,11 @@
         if add_downsample:
             self.downsamplers = nn.ModuleList(
-                [Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
             )
         else:
             self.downsamplers = None
@@ -301,7 +316,8 @@ class UNetResAttnUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        next_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
@@ -310,7 +326,7 @@
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attention_layer_type: str = "self",
+        attention_type="default",
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -319,12 +335,16 @@
         resnets = []
         attentions = []
 
+        self.attention_type = attention_type
+
         for i in range(num_layers):
-            resnet_channels = in_channels if i < num_layers - 1 else next_channels
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
             resnets.append(
                 ResnetBlock(
-                    in_channels=in_channels + resnet_channels,
-                    out_channels=in_channels,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
@@ -337,9 +357,10 @@
             )
             attentions.append(
                 AttentionBlockNew(
-                    in_channels,
+                    out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
                 )
             )
@@ -347,7 +368,7 @@
         self.resnets = nn.ModuleList(resnets)
 
         if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
         else:
             self.upsamplers = None
@@ -373,7 +394,8 @@ class UNetResUpBlock2D(nn.Module):
     def __init__(
         self,
        in_channels: int,
-        next_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
@@ -382,7 +404,6 @@
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attention_layer_type: str = "self",
         output_scale_factor=1.0,
         add_upsample=True,
     ):
@@ -390,11 +411,13 @@
         resnets = []
 
         for i in range(num_layers):
-            resnet_channels = in_channels if i < num_layers - 1 else next_channels
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
            resnets.append(
                 ResnetBlock(
-                    in_channels=in_channels + resnet_channels,
-                    out_channels=in_channels,
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
@@ -409,7 +432,7 @@
         self.resnets = nn.ModuleList(resnets)
 
         if add_upsample:
-            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
         else:
             self.upsamplers = None
......
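The trickiest change above is the up-block channel bookkeeping: `res_skip_channels` is the width of the skip tensor popped from the down path (the matching down block's input width for the last resnet, its output width otherwise), while `resnet_in_channels` is what flows up (the previous up block's output for the first resnet, this block's `out_channels` afterwards). A quick enumeration with made-up widths:

```python
# toy widths; only the two formulas are taken from the diff above
in_channels, prev_output_channel, out_channels, num_layers = 32, 128, 64, 3

for i in range(num_layers):
    res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
    resnet_in_channels = prev_output_channel if i == 0 else out_channels
    print(f"resnet {i}: in={resnet_in_channels + res_skip_channels}, out={out_channels}")

# resnet 0: in=192, out=64  (128 from below + 64 skip)
# resnet 1: in=128, out=64  (64 + 64 skip)
# resnet 2: in=96, out=64   (64 + 32 skip)
```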
This diff is collapsed.
@@ -271,6 +271,27 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
+        print("Original success!!!")
+
+        model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        time_step = torch.tensor([10])
+
+        with torch.no_grad():
+            output = model(noise, time_step)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
+
 
 class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
     model_class = GlideSuperResUNetModel
@@ -486,18 +507,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
             "out_channels": 4,
             "num_res_blocks": 2,
             "attention_resolutions": (16,),
-            "block_input_channels": [32, 32],
-            "block_output_channels": [32, 64],
+            "block_channels": (32, 64),
             "num_head_channels": 32,
             "conv_resample": True,
             "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
             "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
+            "ldm": True,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
     def test_from_pretrained_hub(self):
-        model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
+        model, loading_info = UNetUnconditionalModel.from_pretrained(
+            "fusing/unet-ldm-dummy", output_loading_info=True, ldm=True
+        )
 
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -507,7 +530,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"
 
     def test_output_pretrained(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy")
+        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
         model.eval()
 
         torch.manual_seed(0)
......
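Per the verification plan in the conversion script above, the two tests to rerun are `test_ldm_uncond` and `test_output_pretrained`. One way to invoke just those (a convenience sketch; the test file path follows this commit's layout):

```python
import pytest

# equivalent to: pytest tests/test_modeling_utils.py -k "test_ldm_uncond or test_output_pretrained"
pytest.main(["tests/test_modeling_utils.py", "-k", "test_ldm_uncond or test_output_pretrained"])
```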