"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e0b56d2b189330afed74e984a3309e3877450d42"
Unverified commit 6d5ef87e authored by Patrick von Platen, committed by GitHub

[DDPM] Make DDPM work (#88)

* up

* finish

* uP
parent e7fe901e
@@ -50,6 +50,8 @@ from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel


# 1. LDM
def test_output_pretrained_ldm_dummy():
    model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
    model.eval()
@@ -86,9 +88,38 @@ def test_output_pretrained_ldm():
    import ipdb; ipdb.set_trace()


# To see how the final model should look:
# => this is the architecture in which the model should be saved in the new format
# -> verify the new repo with the following tests (in `test_modeling_utils.py`)
# - test_ldm_uncond (in PipelineTesterMixin)
# - test_output_pretrained (in UNetLDMModelTests)
# test_output_pretrained_ldm_dummy()
# test_output_pretrained_ldm()
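
The comments above name the tests that gate the new LDM format. A minimal sketch of running just those checks programmatically, assuming the file sits at `tests/test_modeling_utils.py` (the path and `-k` filter are assumptions, not part of this commit):

```python
import pytest

# run only the two verification tests named above; -q keeps the output short
pytest.main(["tests/test_modeling_utils.py", "-k", "test_ldm_uncond or test_output_pretrained", "-q"])
```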

# 2. DDPM
def get_model(model_id):
    model = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    print(model)

# Repos to convert and port to google (part of https://github.com/hojonathanho/diffusion)
# - fusing/ddpm_dummy
# - fusing/ddpm-cifar10
# - https://huggingface.co/fusing/ddpm-lsun-church-ema
# - https://huggingface.co/fusing/ddpm-lsun-bedroom-ema
# - https://huggingface.co/fusing/ddpm-celeba-hq

# tests to make sure to pass
# - test_ddim_cifar10, test_ddim_lsun, test_ddpm_cifar10 (in PipelineTesterMixin)
# - test_output_pretrained (in UNetModelTests)

# e.g.
get_model("fusing/ddpm-cifar10")
@@ -492,44 +492,46 @@ class ModelMixin(torch.nn.Module):
                )
                raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
        if False:
            if len(unexpected_keys) > 0:
                logger.warning(
                    f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                    f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                    f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
                    " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                    " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                    f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
                    " identical (initializing a BertForSequenceClassification model from a"
                    " BertForSequenceClassification model)."
                )
            else:
                logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

            if len(missing_keys) > 0:
                logger.warning(
                    f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                    f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                    " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
                )
            elif len(mismatched_keys) == 0:
                logger.info(
                    f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                    f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
                    f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
                    " without further training."
                )
            if len(mismatched_keys) > 0:
                mismatched_warning = "\n".join(
                    [
                        f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                        for key, shape1, shape2 in mismatched_keys
                    ]
                )
                logger.warning(
                    f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                    f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                    f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
                    " able to use it for predictions and inference."
                )

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
...
@@ -166,188 +166,6 @@ class AttentionBlock(nn.Module):
        return result

class AttentionBlockNew_2(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_head_channels=1,
        num_groups=32,
        encoder_channels=None,
        rescale_output_factor=1.0,
        eps=1e-5,
    ):
        super().__init__()
        self.channels = channels
        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.n_heads = channels // num_head_channels
        self.num_head_size = num_head_channels
        self.rescale_output_factor = rescale_output_factor

        if encoder_channels is not None:
            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

        self.proj = zero_module(nn.Conv1d(channels, channels, 1))

        # ------------------------- new -----------------------
        num_heads = self.n_heads
        self.channels = channels

        if num_head_channels is None:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels

        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
        # ------------------------- new -----------------------

    def set_weight(self, attn_layer):
        self.norm.weight.data = attn_layer.norm.weight.data
        self.norm.bias.data = attn_layer.norm.bias.data

        self.qkv.weight.data = attn_layer.qkv.weight.data
        self.qkv.bias.data = attn_layer.qkv.bias.data

        self.proj.weight.data = attn_layer.proj.weight.data
        self.proj.bias.data = attn_layer.proj.bias.data

        if hasattr(attn_layer, "q"):
            module = attn_layer
            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
                :, :, :, 0
            ]
            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)

            self.qkv.weight.data = qkv_weight
            self.qkv.bias.data = qkv_bias

            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
            proj_out.bias.data = module.proj_out.bias.data

            self.proj = proj_out

        self.set_weights_2(attn_layer)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def set_weights_2(self, attn_layer):
        self.group_norm.weight.data = attn_layer.norm.weight.data
        self.group_norm.bias.data = attn_layer.norm.bias.data

        qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels)
        qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)

        q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
        q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)

        self.query.weight.data = q_w.reshape(-1, self.channels)
        self.key.weight.data = k_w.reshape(-1, self.channels)
        self.value.weight.data = v_w.reshape(-1, self.channels)

        self.query.bias.data = q_b.reshape(-1)
        self.key.bias.data = k_b.reshape(-1)
        self.value.bias.data = v_b.reshape(-1)

        self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
        self.proj_attn.bias.data = attn_layer.proj.bias.data

    def forward_2(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        # transpose
        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        # get scores
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # compute attention output
        context_states = torch.matmul(attention_probs, value_states)

        context_states = context_states.permute(0, 2, 1, 3).contiguous()
        new_context_states_shape = context_states.size()[:-2] + (self.channels,)
        context_states = context_states.view(new_context_states_shape)

        # compute next hidden_states
        hidden_states = self.proj_attn(context_states)
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def forward(self, x, encoder_out=None):
        b, c, *spatial = x.shape
        hid_states = self.norm(x).view(b, c, -1)

        qkv = self.qkv(hid_states)
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)

        if encoder_out is not None:
            encoder_kv = self.encoder_kv(encoder_out)
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            k = torch.cat([ek, k], dim=-1)
            v = torch.cat([ev, v], dim=-1)

        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        a = torch.einsum("bts,bcs->bct", weight, v)
        h = a.reshape(bs, -1, length)

        h = self.proj(h)
        h = h.reshape(b, c, *spatial)

        result = x + h
        result = result / self.rescale_output_factor

        result_2 = self.forward_2(x)
        print((result - result_2).abs().sum())

        return result_2
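
The conversion in `set_weights_2` above is the subtle part: the fused `qkv` Conv1d stores per-head q/k/v rows interleaved as [q_h | k_h | v_h], so a plain split on the full weight would mix heads; the reshape-to-heads, split, reshape-back dance recovers contiguous `nn.Linear` weights. A self-contained sketch of that rearrangement with illustrative sizes (not tied to any checkpoint), checkable numerically:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, H = 8, 2   # channels, heads (illustrative)
d = C // H    # per-head dim

qkv = nn.Conv1d(C, 3 * C, 1)  # fused projection, rows interleaved per head

# same reshape/split/reshape as set_weights_2
w = qkv.weight.data.reshape(H, 3 * d, C)
b = qkv.bias.data.reshape(H, 3 * d)
q_w, k_w, v_w = w.split(d, dim=1)
q_b, k_b, v_b = b.split(d, dim=1)

query = nn.Linear(C, C)
query.weight.data = q_w.reshape(-1, C)
query.bias.data = q_b.reshape(-1)

# the q rows sliced out of the fused conv output must match the linear layer
x = torch.randn(1, C, 5)
fused = qkv(x).reshape(1, H, 3 * d, 5)
q_fused = fused[:, :, :d, :]  # first d rows of each head are q
q_linear = query(x.transpose(1, 2)).transpose(1, 2).reshape(1, H, d, 5)
print(torch.allclose(q_fused, q_linear, atol=1e-5))  # True
```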

class AttentionBlockNew(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
@@ -387,7 +205,7 @@ class AttentionBlockNew(nn.Module):
        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection
@@ -434,24 +252,36 @@ class AttentionBlockNew(nn.Module):
        self.group_norm.weight.data = attn_layer.norm.weight.data
        self.group_norm.bias.data = attn_layer.norm.bias.data

        if hasattr(attn_layer, "q"):
            self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
            self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
            self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]

            self.query.bias.data = attn_layer.q.bias.data
            self.key.bias.data = attn_layer.k.bias.data
            self.value.bias.data = attn_layer.v.bias.data

            self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
            self.proj_attn.bias.data = attn_layer.proj_out.bias.data
        else:
            qkv_weight = attn_layer.qkv.weight.data.reshape(
                self.num_heads, 3 * self.channels // self.num_heads, self.channels
            )
            qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)

            q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
            q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)

            self.query.weight.data = q_w.reshape(-1, self.channels)
            self.key.weight.data = k_w.reshape(-1, self.channels)
            self.value.weight.data = v_w.reshape(-1, self.channels)

            self.query.bias.data = q_b.reshape(-1)
            self.key.bias.data = k_b.reshape(-1)
            self.value.bias.data = v_b.reshape(-1)

            self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
            self.proj_attn.bias.data = attn_layer.proj.bias.data

class SpatialTransformer(nn.Module):

...
@@ -87,12 +87,21 @@ class Downsample2D(nn.Module):
        self.conv = conv

    def forward(self, x):
        # print("use_conv", self.use_conv)
        # print("padding", self.padding)
        assert x.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)

        # print("x", x.abs().sum())
        self.hey = x

        assert x.shape[1] == self.channels
        x = self.conv(x)
        self.yas = x
        # print("x", x.abs().sum())

        return x
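
The asymmetric pad above is the detail that keeps the DDPM port numerically aligned: the original TF implementation pads only on the right and bottom before its stride-2 conv, whereas a symmetric `padding=1` conv pads all four sides. A small sketch of the two variants (toy shapes, not the real block):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)

# symmetric: the conv itself pads 1 pixel on every side
sym = nn.Conv2d(4, 4, 3, stride=2, padding=1)

# DDPM-style: no conv padding; pad (left=0, right=1, top=0, bottom=1) by hand
asym = nn.Conv2d(4, 4, 3, stride=2, padding=0)
x_padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

print(sym(x).shape)          # torch.Size([1, 4, 4, 4])
print(asym(x_padded).shape)  # torch.Size([1, 4, 4, 4]) -- same shape, different sampling positions
```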

# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed

...
@@ -177,9 +177,7 @@ class UNetModel(ModelMixin, ConfigMixin):
            hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = self.mid_new(hs[-1], temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):

...
@@ -51,6 +51,7 @@ def get_down_block(
        add_downsample=add_downsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        downsample_padding=downsample_padding,
        attn_num_head_channels=attn_num_head_channels,
    )
@@ -186,6 +187,7 @@ class UNetResAttnDownBlock2D(nn.Module):
        attn_num_head_channels=1,
        attention_type="default",
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
    ):
        super().__init__()
@@ -224,7 +226,11 @@ class UNetResAttnDownBlock2D(nn.Module):
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

...
@@ -94,25 +94,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
    ):
        super().__init__()

        # DELETE if statements if not necessary anymore
        # DDPM
        if ddpm:
            out_channels = out_ch
            image_size = resolution
            block_channels = [x * ch for x in ch_mult]
            conv_resample = resamp_with_conv

            flip_sin_to_cos = False
            downscale_freq_shift = 1
            resnet_eps = 1e-6

            block_channels = (32, 64)
            down_blocks = (
                "UNetResDownBlock2D",
                "UNetResAttnDownBlock2D",
            )
            up_blocks = ("UNetResUpBlock2D", "UNetResAttnUpBlock2D")
            downsample_padding = 0
            num_head_channels = 64
        # register all __init__ params with self.register
        self.register_to_config(
            image_size=image_size,

@@ -250,6 +231,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
            out_channels,
        )

        if ddpm:
            out_channels = out_ch
            image_size = resolution
            block_channels = [x * ch for x in ch_mult]
            conv_resample = resamp_with_conv
            self.init_for_ddpm(
                ch_mult,
                ch,

@@ -290,13 +275,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
            # append to tuple
            down_block_res_samples += res_samples

        # 4. mid block
        if self.config.ddpm:
            sample = self.mid_new_2(sample, emb)
        else:
            sample = self.mid(sample, emb)

        # 5. up blocks
        for upsample_block in self.upsample_blocks:

@@ -373,8 +356,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
        elif self.config.ddpm:
            # =============== SET WEIGHTS ===============
            # =============== TIME ======================
            self.time_embedding.linear_1.weight.data = self.temb.dense[0].weight.data
            self.time_embedding.linear_1.bias.data = self.temb.dense[0].bias.data
            self.time_embedding.linear_2.weight.data = self.temb.dense[1].weight.data
            self.time_embedding.linear_2.bias.data = self.temb.dense[1].bias.data

            for i, block in enumerate(self.down):
                if hasattr(block, "downsample"):

@@ -391,6 +376,23 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
            self.mid_new_2.resnets[1].set_weight(self.mid.block_2)
            self.mid_new_2.attentions[0].set_weight(self.mid.attn_1)

            for i, block in enumerate(self.up):
                k = len(self.up) - 1 - i
                if hasattr(block, "upsample"):
                    self.upsample_blocks[k].upsamplers[0].conv.weight.data = block.upsample.conv.weight.data
                    self.upsample_blocks[k].upsamplers[0].conv.bias.data = block.upsample.conv.bias.data
                if hasattr(block, "block") and len(block.block) > 0:
                    for j in range(self.num_res_blocks + 1):
                        self.upsample_blocks[k].resnets[j].set_weight(block.block[j])
                if hasattr(block, "attn") and len(block.attn) > 0:
                    for j in range(self.num_res_blocks + 1):
                        self.upsample_blocks[k].attentions[j].set_weight(block.attn[j])

            self.conv_norm_out.weight.data = self.norm_out.weight.data
            self.conv_norm_out.bias.data = self.norm_out.bias.data

            self.remove_ddpm()
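
The pattern here is port-then-strip: copy every old submodule's weights into the new block layout, and only afterwards delete the old modules (`remove_ddpm` below). A hedged sketch of the sanity check that belongs between those two steps, with toy modules standing in for the real UNet internals:

```python
import torch
import torch.nn as nn

# toy stand-ins: a 1x1 conv being replaced by a linear layer, as in the attention port
old = nn.Conv2d(4, 4, 1)
new = nn.Linear(4, 4)

# copy the weights the same way the set_weight methods above do
new.weight.data = old.weight.data[:, :, 0, 0]
new.bias.data = old.bias.data

# verify both paths agree on the same input before deleting the old module
x = torch.randn(2, 4, 8, 8)
out_old = old(x)
out_new = new(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
assert torch.allclose(out_old, out_new, atol=1e-6)  # only then is it safe to del old
```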

    def init_for_ddpm(
        self,
        ch_mult,

@@ -685,3 +687,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
        del self.middle_block
        del self.output_blocks
        del self.out

    def remove_ddpm(self):
        del self.temb
        del self.down
        del self.mid_new
        del self.up
        del self.norm_out

@@ -40,7 +40,6 @@ from diffusers import (
    ScoreSdeVpPipeline,
    ScoreSdeVpScheduler,
    UNetLDMModel,
    UNetUnconditionalModel,
    VQModel,
)

@@ -209,7 +208,7 @@ class ModelTesterMixin:

class UnetModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNetUnconditionalModel

    @property
    def dummy_input(self):

@@ -234,15 +233,24 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
        init_dict = {
            "ch": 32,
            "ch_mult": (1, 2),
            "block_channels": (32, 64),
            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
            "num_head_channels": None,
            "out_channels": 3,
            "in_channels": 3,
            "num_res_blocks": 2,
            "attn_resolutions": (16,),
            "resolution": 32,
            "image_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        model, loading_info = UNetUnconditionalModel.from_pretrained(
            "fusing/ddpm_dummy", output_loading_info=True, ddpm=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

@@ -252,27 +260,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetModel.from_pretrained("fusing/ddpm_dummy")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
        time_step = torch.tensor([10])

        with torch.no_grad():
            output = model(noise, time_step)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
        print("Original success!!!")

        model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
        model.eval()
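
The deleted block above shows the regression pattern these tests rely on: seed everything, run a fixed input, and pin a small slice of the output. A sketch of how such an expected slice is typically regenerated after a port, with a toy model standing in for the pretrained checkpoint:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Conv2d(3, 3, 3, padding=1)  # toy stand-in for the ported UNet
model.eval()

noise = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    output = model(noise)

# pin the last 3x3 corner of the last channel; paste the printed values into the test
output_slice = output[0, -1, -3:, -3:].flatten()
print([round(v, 4) for v in output_slice.tolist()])
```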

@@ -849,7 +836,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):

class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        # 1. Load models
        model = UNetUnconditionalModel(
            ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32, ddpm=True
        )
        schedular = DDPMScheduler(timesteps=10)

        ddpm = DDPMPipeline(model, schedular)

@@ -888,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
    def test_ddpm_cifar10(self):
        model_id = "fusing/ddpm-cifar10"

        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
        noise_scheduler = DDPMScheduler.from_config(model_id)
        noise_scheduler = noise_scheduler.set_format("pt")

@@ -901,7 +890,28 @@ class PipelineTesterMixin(unittest.TestCase):
        assert image.shape == (1, 3, 32, 32)
        expected_slice = torch.tensor(
            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_ddim_lsun(self):
        model_id = "fusing/ddpm-lsun-bedroom-ema"

        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
        noise_scheduler = DDIMScheduler.from_config(model_id)
        noise_scheduler = noise_scheduler.set_format("pt")

        ddpm = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator)

        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor(
            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

@@ -909,7 +919,7 @@ class PipelineTesterMixin(unittest.TestCase):
    def test_ddim_cifar10(self):
        model_id = "fusing/ddpm-cifar10"

        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
        noise_scheduler = DDIMScheduler(tensor_format="pt")

        ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)

@@ -929,7 +939,7 @@ class PipelineTesterMixin(unittest.TestCase):
    def test_pndm_cifar10(self):
        model_id = "fusing/ddpm-cifar10"

        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
        noise_scheduler = PNDMScheduler(tensor_format="pt")

        pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)
...