Unverified Commit 5e12d5c6 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Clean uncond unet more (#85)

* up

* finished clean up

* remove @
parent 8aed37c1
...@@ -19,6 +19,74 @@ from .attention import AttentionBlockNew ...@@ -19,6 +19,74 @@ from .attention import AttentionBlockNew
from .resnet import Downsample2D, ResnetBlock, Upsample2D from .resnet import Downsample2D, ResnetBlock, Upsample2D
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
):
    """Instantiate a down-sampling UNet block from its config-string name.

    Args:
        down_block_type: one of ``"UNetResDownBlock2D"`` or ``"UNetResAttnDownBlock2D"``.
        num_layers: number of resnet layers inside the block.
        in_channels / out_channels: channel counts entering/leaving the block.
        temb_channels: dimensionality of the time embedding fed to each resnet.
        add_downsample: whether the block ends with a spatial downsample.
        resnet_eps / resnet_act_fn: normalization eps and activation passed to the resnets.
        attn_num_head_channels: channels per attention head (ignored by the
            plain residual block, which has no attention layers).

    Raises:
        ValueError: if ``down_block_type`` is not a known block name.
    """
    if down_block_type == "UNetResDownBlock2D":
        # BUGFIX: this branch previously constructed UNetResAttnDownBlock2D,
        # so the attention-free block type could never be created.
        return UNetResDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif down_block_type == "UNetResAttnDownBlock2D":
        return UNetResAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attn_num_head_channels=attn_num_head_channels,
        )
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as a confusing error inside nn.ModuleList.append.
    raise ValueError(f"{down_block_type} is not a supported down block type.")
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    next_channels,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
):
    """Instantiate an up-sampling UNet block from its config-string name.

    Args:
        up_block_type: one of ``"UNetResUpBlock2D"`` or ``"UNetResAttnUpBlock2D"``.
        num_layers: number of resnet layers inside the block.
        in_channels: channels entering the block.
        next_channels: channels of the matching down-path block (skip connections).
        temb_channels: dimensionality of the time embedding fed to each resnet.
        add_upsample: whether the block ends with a spatial upsample.
        resnet_eps / resnet_act_fn: normalization eps and activation passed to the resnets.
        attn_num_head_channels: channels per attention head (ignored by the
            plain residual block, which has no attention layers).

    Raises:
        ValueError: if ``up_block_type`` is not a known block name.
    """
    if up_block_type == "UNetResUpBlock2D":
        return UNetResUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            next_channels=next_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif up_block_type == "UNetResAttnUpBlock2D":
        return UNetResAttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            next_channels=next_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attn_num_head_channels=attn_num_head_channels,
        )
    # Fail loudly instead of implicitly returning None, mirroring get_down_block.
    raise ValueError(f"{up_block_type} is not a supported up block type.")
class UNetMidBlock2D(nn.Module): class UNetMidBlock2D(nn.Module):
def __init__( def __init__(
self, self,
......
...@@ -6,13 +6,7 @@ from ..modeling_utils import ModelMixin ...@@ -6,13 +6,7 @@ from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import ( from .unet_new import UNetMidBlock2D, get_down_block, get_up_block
UNetMidBlock2D,
UNetResAttnDownBlock2D,
UNetResAttnUpBlock2D,
UNetResDownBlock2D,
UNetResUpBlock2D,
)
class UNetUnconditionalModel(ModelMixin, ConfigMixin): class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...@@ -39,6 +33,188 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): ...@@ -39,6 +33,188 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
increased efficiency. increased efficiency.
""" """
def __init__(
    self,
    image_size,
    in_channels,
    out_channels,
    num_res_blocks,
    dropout=0,
    block_input_channels=(224, 224, 448, 672),
    block_output_channels=(224, 448, 672, 896),
    down_blocks=(
        "UNetResDownBlock2D",
        "UNetResAttnDownBlock2D",
        "UNetResAttnDownBlock2D",
        "UNetResAttnDownBlock2D",
    ),
    up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
    resnet_act_fn="silu",
    resnet_eps=1e-5,
    conv_resample=True,
    num_head_channels=32,
    # To delete once weights are converted
    attention_resolutions=(8, 4, 2),
):
    """Build an unconditional UNet: conv-in, time MLP, down blocks, mid block,
    up blocks, and a GroupNorm/SiLU/conv output head.

    ``down_blocks``/``up_blocks`` name the block class per resolution level and
    are resolved via ``get_down_block``/``get_up_block``; the i-th level maps
    ``block_input_channels[i] -> block_output_channels[i]`` (reversed on the up
    path). ``attention_resolutions`` is kept only for the legacy
    ``init_for_ldm`` weight-conversion path.
    """
    super().__init__()
    # register all __init__ params with self.register so the model can be
    # re-instantiated from its saved config
    self.register_to_config(
        image_size=image_size,
        in_channels=in_channels,
        block_input_channels=block_input_channels,
        block_output_channels=block_output_channels,
        out_channels=out_channels,
        num_res_blocks=num_res_blocks,
        down_blocks=down_blocks,
        up_blocks=up_blocks,
        dropout=dropout,
        conv_resample=conv_resample,
        num_head_channels=num_head_channels,
        # (TODO(PVP) - To delete once weights are converted
        attention_resolutions=attention_resolutions,
    )

    # To delete - replace with config values
    self.image_size = image_size
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.num_res_blocks = num_res_blocks
    self.dropout = dropout

    # Time embedding is 4x the first block width, a common UNet convention here.
    time_embed_dim = block_input_channels[0] * 4

    # ======================== Input ===================
    self.conv_in = nn.Conv2d(in_channels, block_input_channels[0], kernel_size=3, padding=(1, 1))

    # ======================== Time ====================
    # Two-layer MLP applied to the sinusoidal timestep embedding in forward().
    self.time_embed = nn.Sequential(
        nn.Linear(block_input_channels[0], time_embed_dim),
        nn.SiLU(),
        nn.Linear(time_embed_dim, time_embed_dim),
    )

    # ======================== Down ====================
    input_channels = list(block_input_channels)
    output_channels = list(block_output_channels)

    self.downsample_blocks = nn.ModuleList([])
    for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
        down_block_type = down_blocks[i]
        # The deepest level keeps full resolution (no trailing downsample).
        is_final_block = i == len(input_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=num_res_blocks,
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attn_num_head_channels=num_head_channels,
        )
        self.downsample_blocks.append(down_block)

    # ======================== Mid ====================
    self.mid = UNetMidBlock2D(
        in_channels=output_channels[-1],
        dropout=dropout,
        temb_channels=time_embed_dim,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_time_scale_shift="default",
        attn_num_head_channels=num_head_channels,
    )

    # ======================== Up ====================
    # Mirror the down path: channel lists reversed, one extra resnet layer per
    # block (num_res_blocks + 1) to consume the skip connections.
    self.upsample_blocks = nn.ModuleList([])
    for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
        up_block_type = up_blocks[i]
        is_final_block = i == len(input_channels) - 1

        up_block = get_up_block(
            up_block_type,
            num_layers=num_res_blocks + 1,
            in_channels=output_channel,
            next_channels=input_channel,
            temb_channels=time_embed_dim,
            add_upsample=not is_final_block,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attn_num_head_channels=num_head_channels,
        )
        self.upsample_blocks.append(up_block)

    # ======================== Out ====================
    # NOTE(review): GroupNorm uses output_channels[0] while the conv uses
    # block_input_channels[0]; these agree for the default config (both 224)
    # but may diverge for asymmetric channel configs — confirm intent.
    self.out = nn.Sequential(
        nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
        nn.SiLU(),
        nn.Conv2d(block_input_channels[0], out_channels, 3, padding=1),
    )

    # =========== TO DELETE AFTER CONVERSION ==========
    # Legacy LDM-style construction kept only so pretrained checkpoints can be
    # converted into the new module layout; hard-coded values below are the
    # LDM defaults.
    transformer_depth = 1
    context_dim = None
    legacy = True
    num_heads = -1
    model_channels = block_input_channels[0]
    channel_mult = tuple([x // model_channels for x in block_output_channels])
    self.init_for_ldm(
        in_channels,
        model_channels,
        channel_mult,
        num_res_blocks,
        dropout,
        time_embed_dim,
        attention_resolutions,
        num_head_channels,
        num_heads,
        legacy,
        False,
        transformer_depth,
        context_dim,
        conv_resample,
        out_channels,
    )
def forward(self, sample, timesteps=None):
    """Run one denoising pass: embed the timestep, go down the UNet while
    stacking skip activations, through the mid block, then back up consuming
    the skips, and project to ``out_channels``."""
    # 1. time step embeddings — promote a scalar timestep to a 1-element
    # long tensor on the sample's device.
    if not torch.is_tensor(timesteps):
        timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)

    temb = self.time_embed(
        get_timestep_embedding(
            timesteps, self.config.block_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
        )
    )

    # 2. pre-process sample
    # sample = sample.type(self.dtype_)
    hidden = self.conv_in(sample)

    # 3. down blocks — keep every intermediate activation as a skip connection.
    skips = [hidden]
    for block in self.downsample_blocks:
        hidden, block_skips = block(hidden, temb)
        skips.extend(block_skips)

    # 4. mid block
    hidden = self.mid(hidden, temb)

    # 5. up blocks — each block pops as many skips as it has resnets (LIFO).
    for block in self.upsample_blocks:
        n_skips = len(block.resnets)
        hidden = block(hidden, tuple(skips[-n_skips:]), temb)
        del skips[-n_skips:]

    # 6. post-process sample
    return self.out(hidden)
def init_for_ldm( def init_for_ldm(
self, self,
in_channels, in_channels,
...@@ -252,200 +428,3 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): ...@@ -252,200 +428,3 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.conv_in.weight.data = self.input_blocks[0][0].weight.data self.conv_in.weight.data = self.input_blocks[0][0].weight.data
self.conv_in.bias.data = self.input_blocks[0][0].bias.data self.conv_in.bias.data = self.input_blocks[0][0].bias.data
def __init__(
    self,
    image_size,
    in_channels,
    out_channels,
    num_res_blocks,
    attention_resolutions,
    dropout=0,
    resnet_input_channels=(224, 224, 448, 672),
    resnet_output_channels=(224, 448, 672, 896),
    conv_resample=True,
    num_head_channels=32,
):
    """Legacy constructor (pre-refactor): selects attention vs. plain blocks
    per level via ``attention_resolutions`` and a running downsample factor
    ``ds_new``, instead of explicit ``down_blocks``/``up_blocks`` name lists.
    """
    super().__init__()
    # register all __init__ params with self.register so the model can be
    # re-instantiated from its saved config
    self.register_to_config(
        image_size=image_size,
        in_channels=in_channels,
        resnet_input_channels=resnet_input_channels,
        resnet_output_channels=resnet_output_channels,
        out_channels=out_channels,
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_resolutions,
        dropout=dropout,
        conv_resample=conv_resample,
        num_head_channels=num_head_channels,
    )

    # To delete - replace with config values
    self.image_size = image_size
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.num_res_blocks = num_res_blocks
    self.attention_resolutions = attention_resolutions
    self.dropout = dropout

    # Time embedding is 4x the first block width.
    time_embed_dim = resnet_input_channels[0] * 4

    # ======================== Input ===================
    self.conv_in = nn.Conv2d(in_channels, resnet_input_channels[0], kernel_size=3, padding=(1, 1))

    # ======================== Time ====================
    # Two-layer MLP applied to the sinusoidal timestep embedding in forward().
    self.time_embed = nn.Sequential(
        nn.Linear(resnet_input_channels[0], time_embed_dim),
        nn.SiLU(),
        nn.Linear(time_embed_dim, time_embed_dim),
    )

    # ======================== Down ====================
    input_channels = list(resnet_input_channels)
    output_channels = list(resnet_output_channels)

    # ds_new tracks the current downsample factor; a level gets attention
    # iff its factor appears in attention_resolutions.
    ds_new = 1
    self.downsample_blocks = nn.ModuleList([])
    for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
        # The deepest level keeps full resolution (no trailing downsample).
        is_final_block = i == len(input_channels) - 1
        if ds_new in attention_resolutions:
            down_block = UNetResAttnDownBlock2D(
                num_layers=num_res_blocks,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=1e-5,
                resnet_act_fn="silu",
                attn_num_head_channels=num_head_channels,
            )
        else:
            down_block = UNetResDownBlock2D(
                num_layers=num_res_blocks,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=1e-5,
                resnet_act_fn="silu",
            )
        self.downsample_blocks.append(down_block)
        ds_new *= 2

    # NOTE(review): indentation reconstructed from a diff render — this halving
    # presumably compensates for the final block not downsampling, so ds_new
    # matches the deepest actual resolution before the up path; confirm.
    ds_new = ds_new / 2

    # ======================== Mid ====================
    self.mid = UNetMidBlock2D(
        in_channels=output_channels[-1],
        dropout=dropout,
        temb_channels=time_embed_dim,
        resnet_eps=1e-5,
        resnet_act_fn="silu",
        resnet_time_scale_shift="default",
        attn_num_head_channels=num_head_channels,
    )

    # ======================== Up ====================
    # Mirror the down path: channel lists reversed, one extra resnet layer per
    # block (num_res_blocks + 1) to consume the skip connections.
    self.upsample_blocks = nn.ModuleList([])
    for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
        is_final_block = i == len(input_channels) - 1
        if ds_new in attention_resolutions:
            up_block = UNetResAttnUpBlock2D(
                num_layers=num_res_blocks + 1,
                in_channels=output_channel,
                next_channels=input_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=1e-5,
                resnet_act_fn="silu",
                attn_num_head_channels=num_head_channels,
            )
        else:
            up_block = UNetResUpBlock2D(
                num_layers=num_res_blocks + 1,
                in_channels=output_channel,
                next_channels=input_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=1e-5,
                resnet_act_fn="silu",
            )
        self.upsample_blocks.append(up_block)
        # Halve the factor as we move back up toward full resolution.
        ds_new /= 2

    # ======================== Out ====================
    self.out = nn.Sequential(
        nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
        nn.SiLU(),
        nn.Conv2d(resnet_input_channels[0], out_channels, 3, padding=1),
    )

    # =========== TO DELETE AFTER CONVERSION ==========
    # Legacy LDM-style construction kept only so pretrained checkpoints can be
    # converted into the new module layout.
    transformer_depth = 1
    context_dim = None
    legacy = True
    num_heads = -1
    model_channels = resnet_input_channels[0]
    channel_mult = tuple([x // model_channels for x in resnet_output_channels])
    self.init_for_ldm(
        in_channels,
        model_channels,
        channel_mult,
        num_res_blocks,
        dropout,
        time_embed_dim,
        attention_resolutions,
        num_head_channels,
        num_heads,
        legacy,
        False,
        transformer_depth,
        context_dim,
        conv_resample,
        out_channels,
    )
def forward(self, sample, timesteps=None):
    """Denoising pass (legacy config keys): timestep embedding, down path with
    skip collection, mid block, up path consuming skips, output head."""
    # 1. time step embeddings
    if not torch.is_tensor(timesteps):
        timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
    projected = get_timestep_embedding(
        timesteps, self.config.resnet_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
    )
    time_emb = self.time_embed(projected)

    # 2. pre-process sample
    # sample = sample.type(self.dtype_)
    x = self.conv_in(sample)

    # 3. down blocks — accumulate skip activations, starting with conv-in's.
    residuals = (x,)
    for down_block in self.downsample_blocks:
        x, new_residuals = down_block(x, time_emb)
        residuals = residuals + new_residuals

    # 4. mid block
    x = self.mid(x, time_emb)

    # 5. up blocks — each consumes as many residuals as it has resnets,
    # newest first (the tuple acts as a stack).
    for up_block in self.upsample_blocks:
        count = len(up_block.resnets)
        residuals, consumed = residuals[:-count], residuals[-count:]
        x = up_block(x, consumed, time_emb)

    # 6. post-process sample
    return self.out(x)
...@@ -492,10 +492,12 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -492,10 +492,12 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
"out_channels": 4, "out_channels": 4,
"num_res_blocks": 2, "num_res_blocks": 2,
"attention_resolutions": (16,), "attention_resolutions": (16,),
"resnet_input_channels": [32, 32], "block_input_channels": [32, 32],
"resnet_output_channels": [32, 64], "resnet_output_channels": [32, 64],
"num_head_channels": 32, "num_head_channels": 32,
"conv_resample": True, "conv_resample": True,
"down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
"up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
} }
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment