Unverified Commit 9c3820d0 authored by Patrick von Platen, committed by GitHub

Big Model Renaming (#109)

* up

* change model name

* renaming

* more changes

* up

* up

* up

* save checkpoint

* finish api / naming

* finish config renaming

* rename all weights

* finish really
parent 13e37cab
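For orientation before the diff: a minimal before/after sketch of the renaming this commit introduces. It only restates names visible in the changes below; the checkpoint id comes from the README example further down and may not resolve in every environment.

```python
# old: from diffusers import UNetUnconditionalModel, UNetConditionalModel
from diffusers import UNet2DModel, UNet2DConditionModel

# Weights are now serialized as "diffusion_pytorch_model.bin" (was "diffusion_model.pt").
unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True)

# Renamed submodules:  downsample_blocks -> down_blocks
#                      upsample_blocks   -> up_blocks
#                      mid               -> mid_block
#                      time_steps        -> time_proj
print(len(unet.down_blocks), len(unet.up_blocks))

# Renamed config keys: image_size        -> sample_size
#                      block_channels    -> block_out_channels
#                      num_res_blocks    -> layers_per_block
#                      down_blocks       -> down_block_types ("UNetResAttnDownBlock2D" -> "AttnDownBlock2D")
#                      up_blocks         -> up_block_types   ("UNetResUpBlock2D" -> "UpBlock2D")
#                      num_head_channels -> attention_head_dim
# (assumes the hosted config already uses the new key names)
print(unet.config.sample_size, unet.config.in_channels)
```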
......@@ -84,7 +84,7 @@ For more examples see [schedulers](https://github.com/huggingface/diffusers/tree
```python
import torch
from diffusers import UNetUnconditionalModel, DDIMScheduler
from diffusers import UNet2DModel, DDIMScheduler
import PIL.Image
import numpy as np
import tqdm
......@@ -93,7 +93,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
# 1. Load models
scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt")
unet = UNetUnconditionalModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
# 2. Sample gaussian noise
generator = torch.manual_seed(23)
......
from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
import argparse
import json
import torch
......@@ -80,7 +80,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
continue
new_path = new_path.replace('down.', 'downsample_blocks.')
new_path = new_path.replace('up.', 'upsample_blocks.')
new_path = new_path.replace('up.', 'up_blocks.')
if additional_replacements is not None:
for replacement in additional_replacements:
......@@ -114,8 +114,8 @@ def convert_ddpm_checkpoint(checkpoint, config):
num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
num_upsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
upsample_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_upsample_blocks)}
num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
for i in range(num_downsample_blocks):
block_id = (i - 1) // (config['num_res_blocks'] + 1)
......@@ -164,34 +164,34 @@ def convert_ddpm_checkpoint(checkpoint, config):
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
])
for i in range(num_upsample_blocks):
block_id = num_upsample_blocks - 1 - i
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
if any('upsample' in layer for layer in upsample_blocks[i]):
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
if any('upsample' in layer for layer in up_blocks[i]):
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
if any('block' in layer for layer in upsample_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in upsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if any('block' in layer for layer in up_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if num_blocks > 0:
for j in range(config['num_res_blocks'] + 1):
replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
if any('attn' in layer for layer in upsample_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if any('attn' in layer for layer in up_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if num_attn > 0:
for j in range(config['num_res_blocks'] + 1):
replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
new_checkpoint = {k.replace('mid_new_2', 'mid'): v for k, v in new_checkpoint.items()}
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
return new_checkpoint
......@@ -225,7 +225,7 @@ if __name__ == "__main__":
if "ddpm" in config:
del config["ddpm"]
model = UNetUnconditionalModel(**config)
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
......
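The conversion script above rewrites state-dict keys from the original DDPM layout into the renamed module hierarchy. The snippet below is a hedged, self-contained illustration of that prefix-renaming pattern; it uses only prefixes visible in the diff, while the real script additionally reindexes blocks and splits attention weights.

```python
import torch

# Illustrative subset of the prefix renames; not the script's complete mapping.
PREFIX_RENAMES = {
    "up.": "up_blocks.",    # cf. new_path.replace('up.', 'up_blocks.') above
    "mid.": "mid_block.",   # mid-block keys now live under `mid_block.`
}

def rename_prefixes(state_dict):
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in PREFIX_RENAMES.items():
            new_key = new_key.replace(old, new)
        renamed[new_key] = value
    return renamed

old_sd = {
    "up.0.upsample.conv.weight": torch.zeros(4, 4, 3, 3),
    "mid.attentions.0.key.bias": torch.zeros(4),
}
print(sorted(rename_prefixes(old_sd)))
# ['mid_block.attentions.0.key.bias', 'up_blocks.0.upsample.conv.weight']
```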
......@@ -17,7 +17,7 @@
import argparse
import json
import torch
from diffusers import VQModel, DDPMScheduler, UNetUnconditionalModel, LatentDiffusionUncondPipeline
from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline
def shave_segments(path, n_shave_prefix_segments=1):
......@@ -207,14 +207,14 @@ def convert_ldm_checkpoint(checkpoint, config):
attentions_paths = renew_attention_paths(attentions)
to_split = {
'middle_block.1.qkv.bias': {
'key': 'mid.attentions.0.key.bias',
'query': 'mid.attentions.0.query.bias',
'value': 'mid.attentions.0.value.bias',
'key': 'mid_block.attentions.0.key.bias',
'query': 'mid_block.attentions.0.query.bias',
'value': 'mid_block.attentions.0.value.bias',
},
'middle_block.1.qkv.weight': {
'key': 'mid.attentions.0.key.weight',
'query': 'mid.attentions.0.query.weight',
'value': 'mid.attentions.0.value.weight',
'key': 'mid_block.attentions.0.key.weight',
'query': 'mid_block.attentions.0.query.weight',
'value': 'mid_block.attentions.0.value.weight',
},
}
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
......@@ -239,13 +239,13 @@ def convert_ldm_checkpoint(checkpoint, config):
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
if ['conv.weight', 'conv.bias'] in output_block_list.values():
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
......@@ -255,18 +255,18 @@ def convert_ldm_checkpoint(checkpoint, config):
paths = renew_attention_paths(attentions)
meta_path = {
'old': f'output_blocks.{i}.1',
'new': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}'
'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
}
to_split = {
f'output_blocks.{i}.1.qkv.bias': {
'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
},
f'output_blocks.{i}.1.qkv.weight': {
'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
},
}
assign_to_checkpoint(
......@@ -281,7 +281,7 @@ def convert_ldm_checkpoint(checkpoint, config):
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = '.'.join(['output_blocks', str(i), path['old']])
new_path = '.'.join(['upsample_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
new_checkpoint[new_path] = checkpoint[old_path]
......@@ -319,7 +319,7 @@ if __name__ == "__main__":
if "ldm" in config:
del config["ldm"]
model = UNetUnconditionalModel(**config)
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)
try:
......
......@@ -17,16 +17,16 @@
import argparse
import json
import torch
from diffusers import UNetUnconditionalModel
from diffusers import UNet2DModel
def convert_ncsnpp_checkpoint(checkpoint, config):
"""
Takes a state dict and the path to
"""
new_model_architecture = UNetUnconditionalModel(**config)
new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data
new_model_architecture = UNet2DModel(**config)
new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data
......@@ -92,14 +92,14 @@ def convert_ncsnpp_checkpoint(checkpoint, config):
block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
module_index += 1
set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index)
set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
module_index += 1
set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index)
set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
module_index += 1
set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index)
set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
module_index += 1
for i, block in enumerate(new_model_architecture.upsample_blocks):
for i, block in enumerate(new_model_architecture.up_blocks):
has_attentions = hasattr(block, "attentions")
for j in range(len(block.resnets)):
set_resnet_weights(block.resnets[j], checkpoint, module_index)
......@@ -134,7 +134,7 @@ if __name__ == "__main__":
parser.add_argument(
"--checkpoint_path",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
type=str,
required=False,
help="Path to the checkpoint to convert.",
......@@ -171,7 +171,7 @@ if __name__ == "__main__":
if "sde" in config:
del config["sde"]
model = UNetUnconditionalModel(**config)
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)
try:
......
from huggingface_hub import HfApi
from transformers.file_utils import has_file
from diffusers import UNetUnconditionalModel
from diffusers import UNet2DModel
import random
import torch
api = HfApi()
......@@ -70,19 +70,22 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0
models = api.list_models(filter="diffusers")
for mod in models:
if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
if mod.modelId == "CompVis/ldm-celebahq-256" or not has_file(mod.modelId, "config.json"):
model = UNetUnconditionalModel.from_pretrained(mod.modelId, subfolder = "unet")
print(f"Started running {mod.modelId}!!!")
if mod.modelId.startswith("CompVis"):
model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
else:
model = UNetUnconditionalModel.from_pretrained(mod.modelId)
model = UNet2DModel.from_pretrained(local_checkpoint)
torch.manual_seed(0)
random.seed(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
logits = model(noise, time_step)['sample']
torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
print(f"{mod.modelId} has passed succesfully!!!")
......@@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
......
......@@ -161,10 +161,10 @@ class ConfigMixin:
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
" login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
......
......@@ -34,7 +34,7 @@ from .utils import (
)
WEIGHTS_NAME = "diffusion_model.pt"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
logger = logging.get_logger(__name__)
......@@ -147,7 +147,7 @@ class ModelMixin(torch.nn.Module):
models, `pixel_values` for vision models and `input_values` for speech models).
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
def __init__(self):
super().__init__()
......@@ -341,7 +341,7 @@ class ModelMixin(torch.nn.Module):
subfolder=subfolder,
**kwargs,
)
model.register_to_config(name_or_path=pretrained_model_name_or_path)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
......@@ -497,7 +497,6 @@ class ModelMixin(torch.nn.Module):
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if False:
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
......
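To make the `WEIGHTS_NAME` change above concrete: a minimal local save/reload sketch. Only the file name and the `UNet2DModel` class come from this diff; the directory path and the use of `save_pretrained` on a freshly initialized model are illustrative assumptions.

```python
from diffusers import UNet2DModel

model = UNet2DModel(sample_size=32)   # all other arguments keep their defaults
model.save_pretrained("./tiny-unet")  # writes config.json + diffusion_pytorch_model.bin (was diffusion_model.pt)
reloaded = UNet2DModel.from_pretrained("./tiny-unet")
assert reloaded.config.sample_size == 32
```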
......@@ -16,6 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet_conditional import UNetConditionalModel
from .unet_unconditional import UNetUnconditionalModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel
......@@ -17,7 +17,6 @@ class AttentionBlockNew(nn.Module):
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
rescale_output_factor=1.0,
......@@ -25,14 +24,8 @@ class AttentionBlockNew(nn.Module):
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.num_head_size = num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
......
......@@ -78,12 +78,11 @@ class Downsample2D(nn.Module):
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.Conv2d_0 = conv
self.conv = conv
else:
self.op = conv
self.conv = conv
def forward(self, x):
......
......@@ -9,143 +9,113 @@ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
class UNetUnconditionalModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
rates at which
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
downsampling, attention will be used.
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
image_size=None,
in_channels=None,
out_channels=None,
num_res_blocks=None,
dropout=0,
block_channels=(224, 448, 672, 896),
down_blocks=(
"UNetResDownBlock2D",
"UNetResAttnDownBlock2D",
"UNetResAttnDownBlock2D",
"UNetResAttnDownBlock2D",
),
downsample_padding=1,
up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
resnet_act_fn="silu",
resnet_eps=1e-5,
conv_resample=True,
num_head_channels=32,
flip_sin_to_cos=True,
downscale_freq_shift=0,
sample_size=None,
in_channels=3,
out_channels=3,
center_input_sample=False,
time_embedding_type="positional",
freq_shift=0,
flip_sin_to_cos=True,
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels=(224, 448, 672, 896),
layers_per_block=2,
mid_block_scale_factor=1,
center_input_sample=False,
resnet_num_groups=32,
downsample_padding=1,
act_fn="silu",
attention_head_dim=8,
norm_num_groups=32,
norm_eps=1e-5,
):
super().__init__()
self.image_size = image_size
time_embed_dim = block_channels[0] * 4
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
if time_embedding_type == "fourier":
self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=16)
timestep_input_dim = 2 * block_channels[0]
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
timestep_input_dim = block_channels[0]
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.downsample_blocks = nn.ModuleList([])
self.mid = None
self.upsample_blocks = nn.ModuleList([])
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# down
output_channel = block_channels[0]
for i, down_block_type in enumerate(down_blocks):
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_channels[i]
is_final_block = i == len(block_channels) - 1
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=num_res_blocks,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
self.downsample_blocks.append(down_block)
self.down_blocks.append(down_block)
# mid
self.mid = UNetMidBlock2D(
in_channels=block_channels[-1],
dropout=dropout,
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
resnet_groups=resnet_num_groups,
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
)
# up
reversed_block_channels = list(reversed(block_channels))
output_channel = reversed_block_channels[0]
for i, up_block_type in enumerate(up_blocks):
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_channels[i]
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_channels) - 1
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=num_res_blocks + 1,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
)
self.upsample_blocks.append(up_block)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
......@@ -162,7 +132,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_steps(timesteps)
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
......@@ -171,7 +141,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.downsample_blocks:
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(
hidden_states=sample, temb=emb, skip_sample=skip_sample
......@@ -182,11 +152,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 4. mid
sample = self.mid(sample, emb)
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
for upsample_block in self.upsample_blocks:
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
......
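A quick smoke test of the renamed `UNet2DModel` using the new constructor arguments shown above. The tiny channel sizes are assumptions chosen for speed, and the `["sample"]` output key follows the verification script earlier in this diff.

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,                                        # was `image_size`
    in_channels=3,
    out_channels=3,
    layers_per_block=1,                                    # was `num_res_blocks`
    block_out_channels=(32, 64),                           # was `block_channels`
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),   # was `down_blocks`, with `UNetRes*` prefixes
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),         # was `up_blocks`
    attention_head_dim=8,                                  # was `num_head_channels`
)

noise = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])
with torch.no_grad():
    out = model(noise, timestep)["sample"]
print(out.shape)  # torch.Size([1, 3, 32, 32])
```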
......@@ -9,142 +9,107 @@ from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
class UNetConditionalModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
rates at which
attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
downsampling, attention will be used.
:param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
class UNet2DConditionModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
image_size=None,
sample_size=None,
in_channels=4,
out_channels=4,
num_res_blocks=2,
dropout=0,
block_channels=(320, 640, 1280, 1280),
down_blocks=(
"UNetResCrossAttnDownBlock2D",
"UNetResCrossAttnDownBlock2D",
"UNetResCrossAttnDownBlock2D",
"UNetResDownBlock2D",
),
downsample_padding=1,
up_blocks=(
"UNetResUpBlock2D",
"UNetResCrossAttnUpBlock2D",
"UNetResCrossAttnUpBlock2D",
"UNetResCrossAttnUpBlock2D",
),
resnet_act_fn="silu",
resnet_eps=1e-5,
conv_resample=True,
num_head_channels=8,
center_input_sample=False,
flip_sin_to_cos=True,
downscale_freq_shift=0,
freq_shift=0,
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
downsample_padding=1,
mid_block_scale_factor=1,
center_input_sample=False,
resnet_num_groups=30,
act_fn="silu",
norm_num_groups=32,
norm_eps=1e-5,
attention_head_dim=8,
):
super().__init__()
self.image_size = image_size
time_embed_dim = block_channels[0] * 4
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
timestep_input_dim = block_channels[0]
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.downsample_blocks = nn.ModuleList([])
self.mid = None
self.upsample_blocks = nn.ModuleList([])
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# down
output_channel = block_channels[0]
for i, down_block_type in enumerate(down_blocks):
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_channels[i]
is_final_block = i == len(block_channels) - 1
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=num_res_blocks,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
self.downsample_blocks.append(down_block)
self.down_blocks.append(down_block)
# mid
self.mid = UNetMidBlock2DCrossAttn(
in_channels=block_channels[-1],
dropout=dropout,
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=num_head_channels,
resnet_groups=resnet_num_groups,
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
)
# up
reversed_block_channels = list(reversed(block_channels))
output_channel = reversed_block_channels[0]
for i, up_block_type in enumerate(up_blocks):
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_channels[i]
input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_channels) - 1
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=num_res_blocks + 1,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=num_head_channels,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
)
self.upsample_blocks.append(up_block)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=resnet_num_groups, eps=resnet_eps)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(
self,
......@@ -164,7 +129,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_steps(timesteps)
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
......@@ -172,7 +137,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.downsample_blocks:
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
......@@ -184,10 +149,10 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 4. mid
sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. up
for upsample_block in self.upsample_blocks:
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
......
......@@ -33,8 +33,9 @@ def get_down_block(
attn_num_head_channels,
downsample_padding=None,
):
if down_block_type == "UNetResDownBlock2D":
return UNetResDownBlock2D(
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
return DownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -44,8 +45,8 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
)
elif down_block_type == "UNetResAttnDownBlock2D":
return UNetResAttnDownBlock2D(
elif down_block_type == "AttnDownBlock2D":
return AttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -56,8 +57,8 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResCrossAttnDownBlock2D":
return UNetResCrossAttnDownBlock2D(
elif down_block_type == "CrossAttnDownBlock2D":
return CrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -68,8 +69,8 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "UNetResSkipDownBlock2D":
return UNetResSkipDownBlock2D(
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -79,8 +80,8 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
)
elif down_block_type == "UNetResAttnSkipDownBlock2D":
return UNetResAttnSkipDownBlock2D(
elif down_block_type == "AttnSkipDownBlock2D":
return AttnSkipDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -105,8 +106,9 @@ def get_up_block(
resnet_act_fn,
attn_num_head_channels,
):
if up_block_type == "UNetResUpBlock2D":
return UNetResUpBlock2D(
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
return UpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -116,8 +118,8 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResCrossAttnUpBlock2D":
return UNetResCrossAttnUpBlock2D(
elif up_block_type == "CrossAttnUpBlock2D":
return CrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -128,8 +130,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResAttnUpBlock2D":
return UNetResAttnUpBlock2D(
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -140,8 +142,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UNetResSkipUpBlock2D":
return UNetResSkipUpBlock2D(
elif up_block_type == "SkipUpBlock2D":
return SkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -151,8 +153,8 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "UNetResAttnSkipUpBlock2D":
return UNetResAttnSkipUpBlock2D(
elif up_block_type == "AttnSkipUpBlock2D":
return AttnSkipUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
......@@ -322,7 +324,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
return hidden_states
class UNetResAttnDownBlock2D(nn.Module):
class AttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -403,7 +405,7 @@ class UNetResAttnDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResCrossAttnDownBlock2D(nn.Module):
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -485,7 +487,7 @@ class UNetResCrossAttnDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResDownBlock2D(nn.Module):
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -551,7 +553,7 @@ class UNetResDownBlock2D(nn.Module):
return hidden_states, output_states
class UNetResAttnSkipDownBlock2D(nn.Module):
class AttnSkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -644,7 +646,7 @@ class UNetResAttnSkipDownBlock2D(nn.Module):
return hidden_states, output_states, skip_sample
class UNetResSkipDownBlock2D(nn.Module):
class SkipDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -723,7 +725,7 @@ class UNetResSkipDownBlock2D(nn.Module):
return hidden_states, output_states, skip_sample
class UNetResAttnUpBlock2D(nn.Module):
class AttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -801,7 +803,7 @@ class UNetResAttnUpBlock2D(nn.Module):
return hidden_states
class UNetResCrossAttnUpBlock2D(nn.Module):
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -881,7 +883,7 @@ class UNetResCrossAttnUpBlock2D(nn.Module):
return hidden_states
class UNetResUpBlock2D(nn.Module):
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -944,7 +946,7 @@ class UNetResUpBlock2D(nn.Module):
return hidden_states
class UNetResAttnSkipUpBlock2D(nn.Module):
class AttnSkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -1055,7 +1057,7 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
return hidden_states, skip_sample
class UNetResSkipUpBlock2D(nn.Module):
class SkipUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
......
......@@ -25,7 +25,7 @@ from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging
INDEX_FILE = "diffusion_model.pt"
INDEX_FILE = "diffusion_pytorch_model.bin"
logger = logging.get_logger(__name__)
......
......@@ -28,7 +28,9 @@ class DDIMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
def __call__(
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
......@@ -37,7 +39,7 @@ class DDIMPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
image = image.to(torch_device)
......
......@@ -36,7 +36,7 @@ class DDPMPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
image = image.to(torch_device)
......
......@@ -52,7 +52,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
text_embeddings = self.bert(text_input.input_ids.to(torch_device))
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
latents = latents.to(torch_device)
......
......@@ -24,7 +24,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self.vqvae.to(torch_device)
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
latents = latents.to(torch_device)
......
......@@ -38,7 +38,7 @@ class PNDMPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
image = image.to(torch_device)
......