Unverified Commit 3100bc96 authored by Patrick von Platen, committed by GitHub

[Vae and AutoencoderKL] Final clean of LDM checkpoints (#137)

* [Vae and AutoencoderKL clean]

* save intermediate finished work

* more progress

* more progress

* finish modeling code

* save intermediate

* finish

* Correct tests
parent e05f03ae
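At its core, the conversion below is a key-renaming pass over the original LDM state dict: container names such as `down.{i}`/`up.{i}` become `down_blocks.{i}`/`up_blocks.{i}`, and their `block.{j}`/`attn.{j}` children become `resnets.{j}`/`attentions.{j}`. As a rough illustration only — the commit itself routes every rename through `assign_to_checkpoint`, `renew_resnet_paths` and `renew_attention_paths`, which also reshape attention weights — the mapping looks like this; the helper name `rename_ldm_key` is hypothetical:

    # Illustrative sketch, not part of the commit.
    import re

    def rename_ldm_key(key: str) -> str:
        # "down.{i}" / "up.{i}" containers become "down_blocks.{i}" / "up_blocks.{i}"
        key = re.sub(r"\bdown\.(\d+)\b", r"down_blocks.\1", key)
        key = re.sub(r"\bup\.(\d+)\b", r"up_blocks.\1", key)
        # residual "block.{j}" entries live in a "resnets" ModuleList after the refactor
        key = re.sub(r"\bblock\.(\d+)\b", r"resnets.\1", key)
        key = re.sub(r"\battn\.(\d+)\b", r"attentions.\1", key)
        return key

    print(rename_ldm_key("encoder.down.0.block.1.conv1.weight"))
    # -> encoder.down_blocks.0.resnets.1.conv1.weight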
-from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
+from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
 import argparse
 import json
 import torch
@@ -64,7 +64,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
-            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+            num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3
             old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
             query, key, value = old_tensor.split(channels // num_heads, dim=1)
@@ -79,7 +79,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
         if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue
-        new_path = new_path.replace('down.', 'downsample_blocks.')
+        new_path = new_path.replace('down.', 'down_blocks.')
         new_path = new_path.replace('up.', 'up_blocks.')
         if additional_replacements is not None:
@@ -111,36 +111,36 @@ def convert_ddpm_checkpoint(checkpoint, config):
     new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
     new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']

-    num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
-    downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
+    num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
+    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}

     num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
     up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

-    for i in range(num_downsample_blocks):
-        block_id = (i - 1) // (config['num_res_blocks'] + 1)
+    for i in range(num_down_blocks):
+        block_id = (i - 1) // (config['layers_per_block'] + 1)

-        if any('downsample' in layer for layer in downsample_blocks[i]):
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
+        if any('downsample' in layer for layer in down_blocks[i]):
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']

-        if any('block' in layer for layer in downsample_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in downsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('block' in layer for layer in down_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_blocks > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint)

-        if any('attn' in layer for layer in downsample_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('attn' in layer for layer in down_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_attn > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
@@ -176,7 +176,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
             blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_blocks > 0:
-                for j in range(config['num_res_blocks'] + 1):
+                for j in range(config['layers_per_block'] + 1):
                     replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
@@ -186,7 +186,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
             attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

             if num_attn > 0:
-                for j in range(config['num_res_blocks'] + 1):
+                for j in range(config['layers_per_block'] + 1):
                     replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
@@ -195,6 +195,117 @@ def convert_ddpm_checkpoint(checkpoint, config):
     return new_checkpoint


+def convert_vq_autoenc_checkpoint(checkpoint, config):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
+    """
+    new_checkpoint = {}
+
+    new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
+    new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
+
+    new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
+    new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
+    new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
+    new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']
+
+    new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
+    new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']
+
+    new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
+    new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
+    new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
+    new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']
+
+    num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
+    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
+
+    num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
+    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
+
+    for i in range(num_down_blocks):
+        block_id = (i - 1) // (config['layers_per_block'] + 1)
+
+        if any('downsample' in layer for layer in down_blocks[i]):
+            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
+            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']
+
+        if any('block' in layer for layer in down_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_blocks > 0:
+                for j in range(config['layers_per_block']):
+                    paths = renew_resnet_paths(blocks[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+        if any('attn' in layer for layer in down_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_attn > 0:
+                for j in range(config['layers_per_block']):
+                    paths = renew_attention_paths(attns[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+    mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+    mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+    mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+    # Mid new 2
+    paths = renew_resnet_paths(mid_block_1_layers)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
+    ])
+
+    paths = renew_resnet_paths(mid_block_2_layers)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
+    ])
+
+    paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
+        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
+    ])
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+
+        if any('upsample' in layer for layer in up_blocks[i]):
+            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
+            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
+
+        if any('block' in layer for layer in up_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_blocks > 0:
+                for j in range(config['layers_per_block'] + 1):
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                    paths = renew_resnet_paths(blocks[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+        if any('attn' in layer for layer in up_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+
+            if num_attn > 0:
+                for j in range(config['layers_per_block'] + 1):
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                    paths = renew_attention_paths(attns[j])
+                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+    new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
+    new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
+    new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
+    if "quantize.embedding.weight" in checkpoint:
+        new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
+    new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
+    new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
+
+    return new_checkpoint
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -220,15 +331,29 @@ if __name__ == "__main__":
     with open(args.config_file) as f:
         config = json.loads(f.read())

-    converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
+    # unet case
+    key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
+    if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
+        converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
+    else:
+        converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)

     if "ddpm" in config:
         del config["ddpm"]

-    model = UNet2DModel(**config)
-    model.load_state_dict(converted_checkpoint)
-
-    scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
-    pipe = DDPMPipeline(unet=model, scheduler=scheduler)
-    pipe.save_pretrained(args.dump_path)
+    if config["_class_name"] == "VQModel":
+        model = VQModel(**config)
+        model.load_state_dict(converted_checkpoint)
+        model.save_pretrained(args.dump_path)
+    elif config["_class_name"] == "AutoencoderKL":
+        model = AutoencoderKL(**config)
+        model.load_state_dict(converted_checkpoint)
+        model.save_pretrained(args.dump_path)
+    else:
+        model = UNet2DModel(**config)
+        model.load_state_dict(converted_checkpoint)
+        scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
+        pipe = DDPMPipeline(unet=model, scheduler=scheduler)
+        pipe.save_pretrained(args.dump_path)
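Once the script has written the converted weights, the result loads like any other diffusers model. A minimal usage sketch, assuming the script was run with --dump_path ./converted_vae (a hypothetical path):

    import torch
    from diffusers import AutoencoderKL

    model = AutoencoderKL.from_pretrained("./converted_vae")
    image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
    with torch.no_grad():
        reconstruction = model(image)  # encode to a latent posterior, then decode
    print(reconstruction.shape)        # spatial size matches the input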
@@ -288,7 +288,10 @@ class ResnetBlock(nn.Module):
         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

-        self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+        if temb_channels is not None:
+            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+        else:
+            self.time_emb_proj = None

         self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         self.dropout = torch.nn.Dropout(dropout)
@@ -364,8 +367,9 @@ class ResnetBlock(nn.Module):
         self.conv1.weight.data = resnet.conv1.weight.data
         self.conv1.bias.data = resnet.conv1.bias.data

-        self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
-        self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
+        if self.time_emb_proj is not None:
+            self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
+            self.time_emb_proj.bias.data = resnet.temb_proj.bias.data

         self.norm2.weight.data = resnet.norm2.weight.data
         self.norm2.bias.data = resnet.norm2.bias.data
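The change above makes the time-embedding projection optional so the same resnet block can serve the VAE, which has no timestep conditioning. A self-contained sketch of the pattern (a simplified stand-in, not the actual diffusers ResnetBlock):

    from typing import Optional

    import torch
    import torch.nn as nn

    class TinyResBlock(nn.Module):
        def __init__(self, channels: int, temb_channels: Optional[int] = None):
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)
            # temb_channels=None disables the projection, so the block also runs without timesteps
            self.time_emb_proj = nn.Linear(temb_channels, channels) if temb_channels is not None else None

        def forward(self, x, temb=None):
            h = self.conv(x)
            if self.time_emb_proj is not None and temb is not None:
                h = h + self.time_emb_proj(temb)[:, :, None, None]
            return x + h

    block = TinyResBlock(32)               # VAE-style: no time embedding
    out = block(torch.randn(1, 32, 8, 8))  # temb defaults to None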
@@ -92,6 +92,16 @@ def get_down_block(
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
+    elif down_block_type == "DownEncoderBlock2D":
+        return DownEncoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
+        )


 def get_up_block(
@@ -165,6 +175,15 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             attn_num_head_channels=attn_num_head_channels,
         )
+    elif up_block_type == "UpDecoderBlock2D":
+        return UpDecoderBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+        )
     raise ValueError(f"{up_block_type} does not exist.")
@@ -553,6 +572,66 @@ class DownBlock2D(nn.Module):
         return hidden_states, output_states


+class DownEncoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor=1.0,
+        add_downsample=True,
+        downsample_padding=1,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states):
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states, temb=None)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+        return hidden_states
+
+
 class AttnSkipDownBlock2D(nn.Module):
     def __init__(
         self,
@@ -946,6 +1025,60 @@ class UpBlock2D(nn.Module):
         return hidden_states


+class UpDecoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor=1.0,
+        add_upsample=True,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(self, hidden_states):
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states, temb=None)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
 class AttnSkipUpBlock2D(nn.Module):
     def __init__(
         self,
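A quick shape check for the two new blocks; the import path is an assumption based on the module shown above:

    import torch
    from diffusers.models.unet_blocks import DownEncoderBlock2D, UpDecoderBlock2D

    down = DownEncoderBlock2D(in_channels=32, out_channels=64, num_layers=2, add_downsample=True)
    up = UpDecoderBlock2D(in_channels=64, out_channels=32, num_layers=2, add_upsample=True)

    x = torch.randn(1, 32, 16, 16)
    h = down(x)  # resnets, then a strided "op" downsampler: (1, 64, 8, 8)
    y = up(h)    # resnets, then a conv upsampler: (1, 32, 16, 16)
    print(h.shape, y.shape)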
@@ -4,221 +4,164 @@ import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
-from .attention import AttentionBlock
-from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
-
-
-def nonlinearity(x):
-    # swish
-    return x * torch.sigmoid(x)
-
-
-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


 class Encoder(nn.Module):
     def __init__(
         self,
-        *,
-        ch,
-        ch_mult=(1, 2, 4, 8),
-        num_res_blocks,
-        attn_resolutions,
-        dropout=0.0,
-        resamp_with_conv=True,
-        in_channels,
-        resolution,
-        z_channels,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownEncoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=2,
+        act_fn="silu",
         double_z=True,
-        **ignore_kwargs,
     ):
         super().__init__()
-        self.ch = ch
-        self.temb_ch = 0
-        self.num_resolutions = len(ch_mult)
-        self.num_res_blocks = num_res_blocks
-        self.resolution = resolution
-        self.in_channels = in_channels
+        self.layers_per_block = layers_per_block

-        # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+        self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

-        curr_res = resolution
-        in_ch_mult = (1,) + tuple(ch_mult)
-        self.down = nn.ModuleList()
-        for i_level in range(self.num_resolutions):
-            block = nn.ModuleList()
-            attn = nn.ModuleList()
-            block_in = ch * in_ch_mult[i_level]
-            block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks):
-                block.append(
-                    ResnetBlock2D(
-                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
-                    )
-                )
-                block_in = block_out
-                if curr_res in attn_resolutions:
-                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
-            down = nn.Module()
-            down.block = block
-            down.attn = attn
-            if i_level != self.num_resolutions - 1:
-                down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
-                curr_res = curr_res // 2
-            self.down.append(down)
+        self.mid_block = None
+        self.down_blocks = nn.ModuleList([])

-        # middle
-        self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
-        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
-        self.mid.block_2 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=self.layers_per_block,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                add_downsample=not is_final_block,
+                resnet_eps=1e-6,
+                resnet_act_fn=act_fn,
+                attn_num_head_channels=None,
+                temb_channels=None,
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        self.mid_block = UNetMidBlock2D(
+            in_channels=block_out_channels[-1],
+            resnet_eps=1e-6,
+            resnet_act_fn=act_fn,
+            output_scale_factor=1,
+            resnet_time_scale_shift="default",
+            attn_num_head_channels=None,
+            resnet_groups=32,
+            temb_channels=None,
+        )

-        # end
-        self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(
-            block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
-        )
+        # out
+        num_groups_out = 32
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
+        self.conv_act = nn.SiLU()
+
+        conv_out_channels = 2 * out_channels if double_z else out_channels
+        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

     def forward(self, x):
-        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
-
-        # timestep embedding
-        temb = None
-
-        # downsampling
-        hs = [self.conv_in(x)]
-        for i_level in range(self.num_resolutions):
-            for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](hs[-1], temb)
-                if len(self.down[i_level].attn) > 0:
-                    h = self.down[i_level].attn[i_block](h)
-                hs.append(h)
-            if i_level != self.num_resolutions - 1:
-                hs.append(self.down[i_level].downsample(hs[-1]))
+        sample = x
+        sample = self.conv_in(sample)
+
+        # down
+        for down_block in self.down_blocks:
+            sample = down_block(sample)

         # middle
-        h = hs[-1]
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
+        sample = self.mid_block(sample)

-        # end
-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
-        return h
+        # post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        return sample


 class Decoder(nn.Module):
     def __init__(
         self,
-        *,
-        ch,
-        out_ch,
-        ch_mult=(1, 2, 4, 8),
-        num_res_blocks,
-        attn_resolutions,
-        dropout=0.0,
-        resamp_with_conv=True,
-        in_channels,
-        resolution,
-        z_channels,
-        give_pre_end=False,
-        **ignorekwargs,
+        in_channels=3,
+        out_channels=3,
+        up_block_types=("UpDecoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=2,
+        act_fn="silu",
     ):
         super().__init__()
-        self.ch = ch
-        self.temb_ch = 0
-        self.num_resolutions = len(ch_mult)
-        self.num_res_blocks = num_res_blocks
-        self.resolution = resolution
-        self.in_channels = in_channels
-        self.give_pre_end = give_pre_end
+        self.layers_per_block = layers_per_block

-        # compute in_ch_mult, block_in and curr_res at lowest res
-        block_in = ch * ch_mult[self.num_resolutions - 1]
-        curr_res = resolution // 2 ** (self.num_resolutions - 1)
-        self.z_shape = (1, z_channels, curr_res, curr_res)
-        # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

-        # z to block_in
-        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+        self.mid_block = None
+        self.up_blocks = nn.ModuleList([])

-        # middle
-        self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
-        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
-        self.mid.block_2 = ResnetBlock2D(
-            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
-        )
+        # mid
+        self.mid_block = UNetMidBlock2D(
+            in_channels=block_out_channels[-1],
+            resnet_eps=1e-6,
+            resnet_act_fn=act_fn,
+            output_scale_factor=1,
+            resnet_time_scale_shift="default",
+            attn_num_head_channels=None,
+            resnet_groups=32,
+            temb_channels=None,
+        )

-        # upsampling
-        self.up = nn.ModuleList()
-        for i_level in reversed(range(self.num_resolutions)):
-            block = nn.ModuleList()
-            attn = nn.ModuleList()
-            block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks + 1):
-                block.append(
-                    ResnetBlock2D(
-                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
-                    )
-                )
-                block_in = block_out
-                if curr_res in attn_resolutions:
-                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
-            up = nn.Module()
-            up.block = block
-            up.attn = attn
-            if i_level != 0:
-                up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
-                curr_res = curr_res * 2
-            self.up.insert(0, up)  # prepend to get consistent order
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+
+            is_final_block = i == len(block_out_channels) - 1
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=self.layers_per_block + 1,
+                in_channels=prev_output_channel,
+                out_channels=output_channel,
+                prev_output_channel=None,
+                add_upsample=not is_final_block,
+                resnet_eps=1e-6,
+                resnet_act_fn=act_fn,
+                attn_num_head_channels=None,
+                temb_channels=None,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel

-        # end
-        self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+        # out
+        num_groups_out = 32
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
+        self.conv_act = nn.SiLU()
+        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

     def forward(self, z):
-        # assert z.shape[1:] == self.z_shape[1:]
-        self.last_z_shape = z.shape
-
-        # timestep embedding
-        temb = None
-
-        # z to block_in
-        h = self.conv_in(z)
-
-        # middle
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
-
-        # upsampling
-        for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks + 1):
-                h = self.up[i_level].block[i_block](h, temb)
-                if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h)
-            if i_level != 0:
-                h = self.up[i_level].upsample(h)
-
-        # end
-        if self.give_pre_end:
-            return h
-
-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
-        return h
+        sample = z
+        sample = self.conv_in(sample)
+
+        # middle
+        sample = self.mid_block(sample)
+
+        # up
+        for up_block in self.up_blocks:
+            sample = up_block(sample)
+
+        # post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        return sample


 class VectorQuantizer(nn.Module):
@@ -383,57 +326,44 @@ class VQModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        ch,
-        out_ch,
-        num_res_blocks,
-        attn_resolutions,
-        in_channels,
-        resolution,
-        z_channels,
-        n_embed,
-        embed_dim,
-        remap=None,
-        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
-        ch_mult=(1, 2, 4, 8),
-        dropout=0.0,
-        double_z=True,
-        resamp_with_conv=True,
-        give_pre_end=False,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownEncoderBlock2D",),
+        up_block_types=("UpDecoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=1,
+        act_fn="silu",
+        latent_channels=3,
+        sample_size=32,
+        num_vq_embeddings=256,
     ):
         super().__init__()

         # pass init params to Encoder
         self.encoder = Encoder(
-            ch=ch,
-            num_res_blocks=num_res_blocks,
-            attn_resolutions=attn_resolutions,
             in_channels=in_channels,
-            resolution=resolution,
-            z_channels=z_channels,
-            ch_mult=ch_mult,
-            dropout=dropout,
-            resamp_with_conv=resamp_with_conv,
-            double_z=double_z,
-            give_pre_end=give_pre_end,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            double_z=False,
         )

-        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
-        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)
-        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
+        self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+        self.quantize = VectorQuantizer(
+            num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
+        )
+        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)

         # pass init params to Decoder
         self.decoder = Decoder(
-            ch=ch,
-            out_ch=out_ch,
-            num_res_blocks=num_res_blocks,
-            attn_resolutions=attn_resolutions,
-            in_channels=in_channels,
-            resolution=resolution,
-            z_channels=z_channels,
-            ch_mult=ch_mult,
-            dropout=dropout,
-            resamp_with_conv=resamp_with_conv,
-            give_pre_end=give_pre_end,
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
         )

     def encode(self, x):
@@ -462,57 +392,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        ch,
-        out_ch,
-        num_res_blocks,
-        attn_resolutions,
-        in_channels,
-        resolution,
-        z_channels,
-        embed_dim,
-        remap=None,
-        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
-        ch_mult=(1, 2, 4, 8),
-        dropout=0.0,
-        double_z=True,
-        resamp_with_conv=True,
-        give_pre_end=False,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownEncoderBlock2D",),
+        up_block_types=("UpDecoderBlock2D",),
+        block_out_channels=(64,),
+        layers_per_block=1,
+        act_fn="silu",
+        latent_channels=4,
+        sample_size=32,
     ):
         super().__init__()

         # pass init params to Encoder
         self.encoder = Encoder(
-            ch=ch,
-            out_ch=out_ch,
-            num_res_blocks=num_res_blocks,
-            attn_resolutions=attn_resolutions,
             in_channels=in_channels,
-            resolution=resolution,
-            z_channels=z_channels,
-            ch_mult=ch_mult,
-            dropout=dropout,
-            resamp_with_conv=resamp_with_conv,
-            double_z=double_z,
-            give_pre_end=give_pre_end,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            double_z=True,
        )

         # pass init params to Decoder
         self.decoder = Decoder(
-            ch=ch,
-            out_ch=out_ch,
-            num_res_blocks=num_res_blocks,
-            attn_resolutions=attn_resolutions,
-            in_channels=in_channels,
-            resolution=resolution,
-            z_channels=z_channels,
-            ch_mult=ch_mult,
-            dropout=dropout,
-            resamp_with_conv=resamp_with_conv,
-            give_pre_end=give_pre_end,
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
         )

-        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)

     def encode(self, x):
         h = self.encoder(x)
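With the rewrite, both models are constructed entirely from the declarative config. A small sketch using the same init arguments as the updated tests further below:

    import torch
    from diffusers import VQModel, AutoencoderKL

    vq = VQModel(
        block_out_channels=[64],
        in_channels=3,
        out_channels=3,
        down_block_types=["DownEncoderBlock2D"],
        up_block_types=["UpDecoderBlock2D"],
        latent_channels=3,
    )
    kl = AutoencoderKL(
        block_out_channels=[64],
        in_channels=3,
        out_channels=3,
        down_block_types=["DownEncoderBlock2D"],
        up_block_types=["UpDecoderBlock2D"],
        latent_channels=4,
    )

    image = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        print(vq(image).shape, kl(image).shape)  # both reconstruct to (1, 3, 32, 32)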
@@ -28,8 +28,8 @@ def enable_full_determinism(seed: int):

 def set_seed(seed: int):
     """
+    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.

     Args:
-    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
         seed (`int`): The seed to set.
     """
     random.seed(seed)
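For reference, a minimal sketch of what the documented behavior amounts to (the real helper lives in the module above; the name set_seed_sketch and the numpy call are assumptions based on the docstring):

    import random

    import numpy as np
    import torch

    def set_seed_sketch(seed: int):
        # seed the three RNG sources named in the docstring
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # no-op without CUDA

    set_seed_sketch(0)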
@@ -555,18 +555,12 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "ch": 64,
-            "out_ch": 3,
-            "num_res_blocks": 1,
+            "block_out_channels": [64],
             "in_channels": 3,
-            "attn_resolutions": [],
-            "resolution": 32,
-            "z_channels": 3,
-            "n_embed": 256,
-            "embed_dim": 3,
-            "sane_index_shape": False,
-            "ch_mult": (1,),
-            "double_z": False,
+            "out_channels": 3,
+            "down_block_types": ["DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D"],
+            "latent_channels": 3,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -595,7 +589,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
+        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         with torch.no_grad():
             output = model(image)
@@ -639,6 +633,14 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
-            "resolution": 32,
-            "z_channels": 4,
-        }
+        init_dict = {
+            "block_out_channels": [64],
+            "in_channels": 3,
+            "out_channels": 3,
+            "down_block_types": ["DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D"],
+            "latent_channels": 4,
+        }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -666,13 +668,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)

-        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
+        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         with torch.no_grad():
             output = model(image, sample_posterior=True)

         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
-        expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
+        expected_output_slice = torch.tensor([-0.3900, -0.2800, 0.1281, -0.4449, -0.4890, -0.0207, 0.0784, -0.1258, -0.0409])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))