Commit 8aed37c1 authored by Patrick von Platen

some more refactor

parent 06c79730
@@ -41,7 +41,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
     def init_for_ldm(
         self,
-        dims,
         in_channels,
         model_channels,
         channel_mult,
@@ -80,6 +79,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
                 return nn.Conv3d(*args, **kwargs)
             raise ValueError(f"unsupported dimensions: {dims}")

+        dims = 2
         self.input_blocks = nn.ModuleList(
             [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
         )
@@ -257,27 +257,14 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         self,
         image_size,
         in_channels,
-        model_channels,
         out_channels,
         num_res_blocks,
         attention_resolutions,
         dropout=0,
-        channel_mult=(1, 2, 4, 8),
+        resnet_input_channels=(224, 224, 448, 672),
+        resnet_output_channels=(224, 448, 672, 896),
         conv_resample=True,
-        dims=2,
-        num_classes=None,
-        use_checkpoint=False,
-        use_fp16=False,
-        num_heads=-1,
-        num_head_channels=-1,
-        num_heads_upsample=-1,
-        use_scale_shift_norm=False,
-        resblock_updown=False,
-        use_new_attention_order=False,
-        transformer_depth=1,  # custom transformer support
-        context_dim=None,  # custom transformer support
-        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
-        legacy=True,
+        num_head_channels=32,
     ):
         super().__init__()
@@ -285,57 +272,39 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         self.register_to_config(
             image_size=image_size,
             in_channels=in_channels,
-            model_channels=model_channels,
+            resnet_input_channels=resnet_input_channels,
+            resnet_output_channels=resnet_output_channels,
             out_channels=out_channels,
             num_res_blocks=num_res_blocks,
             attention_resolutions=attention_resolutions,
             dropout=dropout,
-            channel_mult=channel_mult,
             conv_resample=conv_resample,
-            dims=dims,
-            num_classes=num_classes,
-            use_fp16=use_fp16,
-            num_heads=num_heads,
-            num_heads_upsample=num_heads_upsample,
             num_head_channels=num_head_channels,
-            use_scale_shift_norm=use_scale_shift_norm,
-            resblock_updown=resblock_updown,
-            transformer_depth=transformer_depth,
-            context_dim=context_dim,
-            n_embed=n_embed,
-            legacy=legacy,
         )

+        # To delete - replace with config values
         self.image_size = image_size
         self.in_channels = in_channels
-        self.model_channels = model_channels
         self.out_channels = out_channels
         self.num_res_blocks = num_res_blocks
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
-        self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
-        self.num_classes = num_classes
-        self.dtype_ = torch.float16 if use_fp16 else torch.float32
-        self.num_heads = num_heads
-        self.num_heads_upsample = num_heads_upsample
-        self.predict_codebook_ids = n_embed is not None

-        time_embed_dim = model_channels * 4
+        time_embed_dim = resnet_input_channels[0] * 4

         # ======================== Input ===================
-        self.conv_in = nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=(1, 1))
+        self.conv_in = nn.Conv2d(in_channels, resnet_input_channels[0], kernel_size=3, padding=(1, 1))

         # ======================== Time ====================
         self.time_embed = nn.Sequential(
-            nn.Linear(model_channels, time_embed_dim),
+            nn.Linear(resnet_input_channels[0], time_embed_dim),
             nn.SiLU(),
             nn.Linear(time_embed_dim, time_embed_dim),
         )

         # ======================== Down ====================
-        input_channels = [model_channels * mult for mult in [1] + list(channel_mult[:-1])]
-        output_channels = [model_channels * mult for mult in channel_mult]
+        input_channels = list(resnet_input_channels)
+        output_channels = list(resnet_output_channels)

         ds_new = 1
         self.downsample_blocks = nn.ModuleList([])
@@ -377,14 +346,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
             temb_channels=time_embed_dim,
             resnet_eps=1e-5,
             resnet_act_fn="silu",
-            resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
+            resnet_time_scale_shift="default",
             attn_num_head_channels=num_head_channels,
         )

-        # ======================== Up =====================
-        # input_channels = [model_channels * mult for mult in channel_mult]
-        # output_channels = [model_channels * mult for mult in channel_mult]
         self.upsample_blocks = nn.ModuleList([])
         for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
             is_final_block = i == len(input_channels) - 1
@@ -419,12 +384,17 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         self.out = nn.Sequential(
             nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
             nn.SiLU(),
-            nn.Conv2d(model_channels, out_channels, 3, padding=1),
+            nn.Conv2d(resnet_input_channels[0], out_channels, 3, padding=1),
         )

         # =========== TO DELETE AFTER CONVERSION ==========
+        transformer_depth = 1
+        context_dim = None
+        legacy = True
+        num_heads = -1
+        model_channels = resnet_input_channels[0]
+        channel_mult = tuple([x // model_channels for x in resnet_output_channels])
         self.init_for_ldm(
-            dims,
             in_channels,
             model_channels,
             channel_mult,
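Note (not part of the commit): the lines added in this hunk convert the new `resnet_input_channels` / `resnet_output_channels` arguments back into the legacy `model_channels` / `channel_mult` pair so that `init_for_ldm` keeps working. A minimal sketch of the round trip between the two parameterizations, mirroring the list comprehensions removed from the "Down" section and the back-conversion added here (the helper names are illustrative, not from the codebase):

```python
def to_resnet_channels(model_channels, channel_mult):
    # Mirrors the removed list comprehensions in the old "Down" section.
    resnet_input_channels = [model_channels * mult for mult in [1] + list(channel_mult[:-1])]
    resnet_output_channels = [model_channels * mult for mult in channel_mult]
    return resnet_input_channels, resnet_output_channels


def to_model_channels(resnet_input_channels, resnet_output_channels):
    # Mirrors the temporary back-conversion kept for init_for_ldm.
    model_channels = resnet_input_channels[0]
    channel_mult = tuple(x // model_channels for x in resnet_output_channels)
    return model_channels, channel_mult


# The updated test config below: model_channels=32, channel_mult=(1, 2).
assert to_resnet_channels(32, (1, 2)) == ([32, 32], [32, 64])
assert to_model_channels([32, 32], [32, 64]) == (32, (1, 2))
```

With the new defaults, `resnet_input_channels=(224, 224, 448, 672)` and `resnet_output_channels=(224, 448, 672, 896)` correspond to `model_channels=224` and `channel_mult=(1, 2, 3, 4)`.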
@@ -446,11 +416,13 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         # 1. time step embeddings
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
-        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        t_emb = get_timestep_embedding(
+            timesteps, self.config.resnet_input_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
+        )
         emb = self.time_embed(t_emb)

         # 2. pre-process sample
-        sample = sample.type(self.dtype_)
+        # sample = sample.type(self.dtype_)
         sample = self.conv_in(sample)

         # 3. down blocks
@@ -490,10 +490,10 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
             "image_size": 32,
             "in_channels": 4,
             "out_channels": 4,
-            "model_channels": 32,
             "num_res_blocks": 2,
             "attention_resolutions": (16,),
-            "channel_mult": (1, 2),
+            "resnet_input_channels": [32, 32],
+            "resnet_output_channels": [32, 64],
             "num_head_channels": 32,
             "conv_resample": True,
         }
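For reference, a minimal usage sketch of the updated constructor, mirroring the test config above. This is an assumption-laden sketch: the import path and the forward calling convention are inferred from this branch's diff, not confirmed elsewhere.

```python
import torch

# Assumption: UNetUnconditionalModel is importable from this branch's diffusers package.
from diffusers import UNetUnconditionalModel

# Arguments copied from the updated UNetLDMModelTests config above.
model = UNetUnconditionalModel(
    image_size=32,
    in_channels=4,
    out_channels=4,
    num_res_blocks=2,
    attention_resolutions=(16,),
    resnet_input_channels=[32, 32],
    resnet_output_channels=[32, 64],
    num_head_channels=32,
    conv_resample=True,
)

# The forward pass now reads the timestep-embedding width from
# model.config.resnet_input_channels[0] instead of self.model_channels.
sample = torch.randn(1, 4, 32, 32)
timesteps = torch.tensor([10])
with torch.no_grad():
    out = model(sample, timesteps)  # calling convention assumed from the forward() hunk above
```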