Unverified Commit 9c3820d0 authored by Patrick von Platen, committed by GitHub

Big Model Renaming (#109)

* up

* change model name

* renaming

* more changes

* up

* up

* up

* save checkpoint

* finish api / naming

* finish config renaming

* rename all weights

* finish really
parent 13e37cab
@@ -84,7 +84,7 @@ For more examples see [schedulers](https://github.com/huggingface/diffusers/tree
 ```python
 import torch
-from diffusers import UNetUnconditionalModel, DDIMScheduler
+from diffusers import UNet2DModel, DDIMScheduler
 import PIL.Image
 import numpy as np
 import tqdm
@@ -93,7 +93,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 # 1. Load models
 scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt")
-unet = UNetUnconditionalModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
+unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
 # 2. Sample gaussian noise
 generator = torch.manual_seed(23)
...
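Pieced together, the two README hunks above yield the following updated snippet. This is a sketch assembled from the diff and assumes a diffusers build from this commit; the denoising loop that follows in the README is not part of this hunk and is elided:

```python
import torch
from diffusers import UNet2DModel, DDIMScheduler

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load models (renamed class, same checkpoint id as in the diff)
scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt")
unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)

# 2. Sample gaussian noise
generator = torch.manual_seed(23)
# ... the denoising loop continues as in the README
```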
-from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
+from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
 import argparse
 import json
 import torch
@@ -80,7 +80,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             continue
         new_path = new_path.replace('down.', 'downsample_blocks.')
-        new_path = new_path.replace('up.', 'upsample_blocks.')
+        new_path = new_path.replace('up.', 'up_blocks.')
         if additional_replacements is not None:
             for replacement in additional_replacements:
@@ -114,8 +114,8 @@ def convert_ddpm_checkpoint(checkpoint, config):
     num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
     downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
-    num_upsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
-    upsample_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_upsample_blocks)}
+    num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
+    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
     for i in range(num_downsample_blocks):
         block_id = (i - 1) // (config['num_res_blocks'] + 1)
@@ -164,34 +164,34 @@ def convert_ddpm_checkpoint(checkpoint, config):
         {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
     ])
-    for i in range(num_upsample_blocks):
-        block_id = num_upsample_blocks - 1 - i
-        if any('upsample' in layer for layer in upsample_blocks[i]):
-            new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
-            new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
-        if any('block' in layer for layer in upsample_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in upsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+        if any('upsample' in layer for layer in up_blocks[i]):
+            new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
+            new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
+        if any('block' in layer for layer in up_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
             if num_blocks > 0:
                 for j in range(config['num_res_blocks'] + 1):
-                    replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
-        if any('attn' in layer for layer in upsample_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('attn' in layer for layer in up_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
             if num_attn > 0:
                 for j in range(config['num_res_blocks'] + 1):
-                    replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
+                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
-    new_checkpoint = {k.replace('mid_new_2', 'mid'): v for k, v in new_checkpoint.items()}
+    new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
     return new_checkpoint
@@ -225,7 +225,7 @@ if __name__ == "__main__":
     if "ddpm" in config:
         del config["ddpm"]
-    model = UNetUnconditionalModel(**config)
+    model = UNet2DModel(**config)
     model.load_state_dict(converted_checkpoint)
     scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
...
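The converter's core move is renaming checkpoint keys by string replacement before loading. A minimal standalone sketch of that idea follows; the `rename_keys` helper and the two-entry mapping are illustrative, not part of diffusers, though the replacements mirror ones visible in the diff above:

```python
import torch

# Hypothetical helper sketching the key-renaming pattern used by the
# conversion scripts: old prefixes are rewritten to the new module names.
RENAMES = {
    "up.": "up_blocks.",
    "mid.": "mid_block.",
}

def rename_keys(state_dict):
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in RENAMES.items():
            new_key = new_key.replace(old, new)
        new_state_dict[new_key] = value
    return new_state_dict

old = {"up.0.conv.weight": torch.zeros(3, 3), "mid.attn_1.bias": torch.zeros(3)}
print(sorted(rename_keys(old)))  # ['mid_block.attn_1.bias', 'up_blocks.0.conv.weight']
```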
@@ -17,7 +17,7 @@
 import argparse
 import json
 import torch
-from diffusers import VQModel, DDPMScheduler, UNetUnconditionalModel, LatentDiffusionUncondPipeline
+from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline
 def shave_segments(path, n_shave_prefix_segments=1):
@@ -207,14 +207,14 @@ def convert_ldm_checkpoint(checkpoint, config):
         attentions_paths = renew_attention_paths(attentions)
         to_split = {
             'middle_block.1.qkv.bias': {
-                'key': 'mid.attentions.0.key.bias',
-                'query': 'mid.attentions.0.query.bias',
-                'value': 'mid.attentions.0.value.bias',
+                'key': 'mid_block.attentions.0.key.bias',
+                'query': 'mid_block.attentions.0.query.bias',
+                'value': 'mid_block.attentions.0.value.bias',
             },
             'middle_block.1.qkv.weight': {
-                'key': 'mid.attentions.0.key.weight',
-                'query': 'mid.attentions.0.query.weight',
-                'value': 'mid.attentions.0.value.weight',
+                'key': 'mid_block.attentions.0.key.weight',
+                'query': 'mid_block.attentions.0.query.weight',
+                'value': 'mid_block.attentions.0.value.weight',
             },
         }
         assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
@@ -239,13 +239,13 @@ def convert_ldm_checkpoint(checkpoint, config):
             resnet_0_paths = renew_resnet_paths(resnets)
             paths = renew_resnet_paths(resnets)
-            meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
+            meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
             assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
             if ['conv.weight', 'conv.bias'] in output_block_list.values():
                 index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
-                new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
-                new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
+                new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
+                new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
                 # Clear attentions as they have been attributed above.
                 if len(attentions) == 2:
@@ -255,18 +255,18 @@ def convert_ldm_checkpoint(checkpoint, config):
                 paths = renew_attention_paths(attentions)
                 meta_path = {
                     'old': f'output_blocks.{i}.1',
-                    'new': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}'
+                    'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
                 }
                 to_split = {
                     f'output_blocks.{i}.1.qkv.bias': {
-                        'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
-                        'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
-                        'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
+                        'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
+                        'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
+                        'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
                     },
                     f'output_blocks.{i}.1.qkv.weight': {
-                        'key': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
-                        'query': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
-                        'value': f'upsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
+                        'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
+                        'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
+                        'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
                     },
                 }
                 assign_to_checkpoint(
@@ -281,7 +281,7 @@ def convert_ldm_checkpoint(checkpoint, config):
     resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
     for path in resnet_0_paths:
         old_path = '.'.join(['output_blocks', str(i), path['old']])
-        new_path = '.'.join(['upsample_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
+        new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
         new_checkpoint[new_path] = checkpoint[old_path]
@@ -319,7 +319,7 @@ if __name__ == "__main__":
     if "ldm" in config:
         del config["ldm"]
-    model = UNetUnconditionalModel(**config)
+    model = UNet2DModel(**config)
     model.load_state_dict(converted_checkpoint)
     try:
...
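The `to_split` maps above route slices of a fused `qkv` tensor to separate `key`/`query`/`value` destinations. A sketch of the underlying idea, assuming the fused tensor stacks the three projections along the first dimension (the actual slicing lives in `assign_to_checkpoint`, which this hunk does not show):

```python
import torch

# Sketch: split a fused qkv projection into three tensors.
# Assumes q, k, v are stacked along dim 0; verify against the real
# assign_to_checkpoint logic before relying on the ordering.
def split_qkv(qkv: torch.Tensor):
    query, key, value = torch.chunk(qkv, 3, dim=0)
    return {"query": query, "key": key, "value": value}

fused = torch.randn(3 * 64, 64)  # e.g. a middle_block.1.qkv.weight-like tensor
parts = split_qkv(fused)
assert parts["query"].shape == (64, 64)
```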
@@ -17,16 +17,16 @@
 import argparse
 import json
 import torch
-from diffusers import UNetUnconditionalModel
+from diffusers import UNet2DModel
 def convert_ncsnpp_checkpoint(checkpoint, config):
     """
     Takes a state dict and the path to
     """
-    new_model_architecture = UNetUnconditionalModel(**config)
-    new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data
-    new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data
+    new_model_architecture = UNet2DModel(**config)
+    new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
+    new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
     new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
     new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data
@@ -92,14 +92,14 @@ def convert_ncsnpp_checkpoint(checkpoint, config):
             block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
             module_index += 1
-    set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index)
+    set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
     module_index += 1
-    set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index)
+    set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
     module_index += 1
-    set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index)
+    set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
     module_index += 1
-    for i, block in enumerate(new_model_architecture.upsample_blocks):
+    for i, block in enumerate(new_model_architecture.up_blocks):
         has_attentions = hasattr(block, "attentions")
         for j in range(len(block.resnets)):
             set_resnet_weights(block.resnets[j], checkpoint, module_index)
@@ -134,7 +134,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--checkpoint_path",
-        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt",
+        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
         type=str,
         required=False,
         help="Path to the checkpoint to convert.",
@@ -171,7 +171,7 @@ if __name__ == "__main__":
     if "sde" in config:
         del config["sde"]
-    model = UNetUnconditionalModel(**config)
+    model = UNet2DModel(**config)
     model.load_state_dict(converted_checkpoint)
     try:
...
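Unlike the key-renaming converters, the NCSNpp script copies tensors positionally from a flat `all_modules.N` checkpoint into the new module tree via `.data` assignment. A minimal sketch of that pattern (the `all_modules.1` key names appear in the diff; the target layer here is illustrative):

```python
import torch
import torch.nn as nn

# Sketch: positional weight copy from a flat checkpoint into named modules.
# The shapes and the nn.Linear target are illustrative only.
linear_1 = nn.Linear(128, 512)
checkpoint = {
    "all_modules.1.weight": torch.randn(512, 128),
    "all_modules.1.bias": torch.randn(512),
}
linear_1.weight.data = checkpoint["all_modules.1.weight"].data
linear_1.bias.data = checkpoint["all_modules.1.bias"].data
```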
 from huggingface_hub import HfApi
 from transformers.file_utils import has_file
-from diffusers import UNetUnconditionalModel
+from diffusers import UNet2DModel
 import random
 import torch
 api = HfApi()
@@ -70,19 +70,22 @@ results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0
 models = api.list_models(filter="diffusers")
 for mod in models:
     if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
+        local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]
+        print(f"Started running {mod.modelId}!!!")
-        if mod.modelId == "CompVis/ldm-celebahq-256" or not has_file(mod.modelId, "config.json"):
-            model = UNetUnconditionalModel.from_pretrained(mod.modelId, subfolder = "unet")
+        if mod.modelId.startswith("CompVis"):
+            model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
         else:
-            model = UNetUnconditionalModel.from_pretrained(mod.modelId)
+            model = UNet2DModel.from_pretrained(local_checkpoint)
         torch.manual_seed(0)
         random.seed(0)
-        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         time_step = torch.tensor([10] * noise.shape[0])
         with torch.no_grad():
             logits = model(noise, time_step)['sample']
-        torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
+        assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
         print(f"{mod.modelId} has passed succesfully!!!")
@@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"
 from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
+from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
 from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
...
@@ -161,10 +161,10 @@ class ConfigMixin:
         except RepositoryNotFoundError:
             raise EnvironmentError(
-                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
-                " on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
-                " having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
-                " pass `use_auth_token=True`."
+                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+                " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+                " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+                " login` and pass `use_auth_token=True`."
             )
         except RevisionNotFoundError:
             raise EnvironmentError(
...
@@ -34,7 +34,7 @@ from .utils import (
 )
-WEIGHTS_NAME = "diffusion_model.pt"
+WEIGHTS_NAME = "diffusion_pytorch_model.bin"
 logger = logging.get_logger(__name__)
@@ -147,7 +147,7 @@ class ModelMixin(torch.nn.Module):
     models, `pixel_values` for vision models and `input_values` for speech models).
     """
     config_name = CONFIG_NAME
-    _automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
+    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     def __init__(self):
         super().__init__()
@@ -341,7 +341,7 @@ class ModelMixin(torch.nn.Module):
             subfolder=subfolder,
             **kwargs,
         )
-        model.register_to_config(name_or_path=pretrained_model_name_or_path)
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # Load model
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -497,46 +497,45 @@ class ModelMixin(torch.nn.Module):
             )
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
-        if False:
         if len(unexpected_keys) > 0:
             logger.warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
                 " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                 " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                 f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
                 " identical (initializing a BertForSequenceClassification model from a"
                 " BertForSequenceClassification model)."
             )
         else:
             logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
         if len(missing_keys) > 0:
             logger.warning(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
             )
         elif len(mismatched_keys) == 0:
             logger.info(
                 f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
                 f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
                 " without further training."
             )
         if len(mismatched_keys) > 0:
             mismatched_warning = "\n".join(
                 [
                     f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                     for key, shape1, shape2 in mismatched_keys
                 ]
             )
             logger.warning(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                 f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
                 " able to use it for predictions and inference."
             )
         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
...
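Removing the `if False:` guard re-enables the missing/unexpected-key reporting after a non-strict load. A self-contained sketch of that pattern using plain PyTorch (the model and checkpoint here are illustrative):

```python
import torch
import torch.nn as nn

# Sketch: load non-strictly, then surface missing/unexpected keys
# instead of silently ignoring them, mirroring the re-enabled warnings.
model = nn.Linear(4, 4)
state_dict = {"weight": torch.zeros(4, 4), "extra.bias": torch.zeros(4)}
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if unexpected:
    print(f"Some checkpoint weights were not used: {unexpected}")
if missing:
    print(f"Some model weights were newly initialized: {missing}")
```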
@@ -16,6 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .unet_conditional import UNetConditionalModel
-from .unet_unconditional import UNetUnconditionalModel
+from .unet_2d import UNet2DModel
+from .unet_2d_condition import UNet2DConditionModel
 from .vae import AutoencoderKL, VQModel
@@ -17,7 +17,6 @@ class AttentionBlockNew(nn.Module):
     def __init__(
         self,
         channels,
-        num_heads=1,
         num_head_channels=None,
         num_groups=32,
         rescale_output_factor=1.0,
@@ -25,14 +24,8 @@ class AttentionBlockNew(nn.Module):
     ):
         super().__init__()
         self.channels = channels
-        if num_head_channels is None:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
+        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
...
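The branching above collapses to a single expression: the head count is derived from `num_head_channels` when given, otherwise it defaults to one head. A quick check of that rule with illustrative channel counts:

```python
# Sketch of the simplified head-count rule from the hunk above.
def num_heads(channels: int, num_head_channels=None) -> int:
    return channels // num_head_channels if num_head_channels is not None else 1

assert num_heads(512, 64) == 8
assert num_heads(512) == 1
```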
@@ -78,12 +78,11 @@ class Downsample2D(nn.Module):
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
+            self.Conv2d_0 = conv
             self.conv = conv
         elif name == "Conv2d_0":
-            self.Conv2d_0 = conv
             self.conv = conv
         else:
-            self.op = conv
             self.conv = conv
     def forward(self, x):
...
@@ -9,143 +9,113 @@ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
-class UNetUnconditionalModel(ModelMixin, ConfigMixin):
-    """
-    The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
-    model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
-    num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
-    rates at which
-        attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
-        downsampling, attention will be used.
-    :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
-    conv_resample: if True, use learned convolutions for upsampling and
-        downsampling.
-    :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
-    model will be
-        class-conditional with `num_classes` classes.
-    :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
-    heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
-        a fixed channel width per attention head.
-    :param num_heads_upsample: works with num_heads to set a different number
-        of heads for upsampling. Deprecated.
-    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
-    for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
-        increased efficiency.
-    """
+class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        image_size=None,
-        in_channels=None,
-        out_channels=None,
-        num_res_blocks=None,
-        dropout=0,
-        block_channels=(224, 448, 672, 896),
-        down_blocks=(
-            "UNetResDownBlock2D",
-            "UNetResAttnDownBlock2D",
-            "UNetResAttnDownBlock2D",
-            "UNetResAttnDownBlock2D",
-        ),
-        downsample_padding=1,
-        up_blocks=("UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
-        resnet_act_fn="silu",
-        resnet_eps=1e-5,
-        conv_resample=True,
-        num_head_channels=32,
-        flip_sin_to_cos=True,
-        downscale_freq_shift=0,
+        sample_size=None,
+        in_channels=3,
+        out_channels=3,
+        center_input_sample=False,
         time_embedding_type="positional",
+        freq_shift=0,
+        flip_sin_to_cos=True,
+        down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+        up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+        block_out_channels=(224, 448, 672, 896),
+        layers_per_block=2,
         mid_block_scale_factor=1,
-        center_input_sample=False,
-        resnet_num_groups=32,
+        downsample_padding=1,
+        act_fn="silu",
+        attention_head_dim=8,
+        norm_num_groups=32,
+        norm_eps=1e-5,
     ):
         super().__init__()
-        self.image_size = image_size
-        time_embed_dim = block_channels[0] * 4
+        self.sample_size = sample_size
+        time_embed_dim = block_out_channels[0] * 4
         # input
-        self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
+        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
         # time
         if time_embedding_type == "fourier":
-            self.time_steps = GaussianFourierProjection(embedding_size=block_channels[0], scale=16)
-            timestep_input_dim = 2 * block_channels[0]
+            self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
+            timestep_input_dim = 2 * block_out_channels[0]
         elif time_embedding_type == "positional":
-            self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
-            timestep_input_dim = block_channels[0]
+            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+            timestep_input_dim = block_out_channels[0]
         self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
-        self.downsample_blocks = nn.ModuleList([])
-        self.mid = None
-        self.upsample_blocks = nn.ModuleList([])
+        self.down_blocks = nn.ModuleList([])
+        self.mid_block = None
+        self.up_blocks = nn.ModuleList([])
         # down
-        output_channel = block_channels[0]
-        for i, down_block_type in enumerate(down_blocks):
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
             input_channel = output_channel
-            output_channel = block_channels[i]
-            is_final_block = i == len(block_channels) - 1
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
             down_block = get_down_block(
                 down_block_type,
-                num_layers=num_res_blocks,
+                num_layers=layers_per_block,
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=time_embed_dim,
                 add_downsample=not is_final_block,
-                resnet_eps=resnet_eps,
-                resnet_act_fn=resnet_act_fn,
-                attn_num_head_channels=num_head_channels,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                attn_num_head_channels=attention_head_dim,
                 downsample_padding=downsample_padding,
             )
-            self.downsample_blocks.append(down_block)
+            self.down_blocks.append(down_block)
         # mid
-        self.mid = UNetMidBlock2D(
-            in_channels=block_channels[-1],
-            dropout=dropout,
+        self.mid_block = UNetMidBlock2D(
+            in_channels=block_out_channels[-1],
             temb_channels=time_embed_dim,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
+            resnet_eps=norm_eps,
+            resnet_act_fn=act_fn,
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift="default",
-            attn_num_head_channels=num_head_channels,
-            resnet_groups=resnet_num_groups,
+            attn_num_head_channels=attention_head_dim,
+            resnet_groups=norm_num_groups,
         )
         # up
-        reversed_block_channels = list(reversed(block_channels))
-        output_channel = reversed_block_channels[0]
-        for i, up_block_type in enumerate(up_blocks):
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
             prev_output_channel = output_channel
-            output_channel = reversed_block_channels[i]
-            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
-            is_final_block = i == len(block_channels) - 1
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+            is_final_block = i == len(block_out_channels) - 1
             up_block = get_up_block(
                 up_block_type,
-                num_layers=num_res_blocks + 1,
+                num_layers=layers_per_block + 1,
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 temb_channels=time_embed_dim,
                 add_upsample=not is_final_block,
-                resnet_eps=resnet_eps,
-                resnet_act_fn=resnet_act_fn,
-                attn_num_head_channels=num_head_channels,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                attn_num_head_channels=attention_head_dim,
             )
-            self.upsample_blocks.append(up_block)
+            self.up_blocks.append(up_block)
             prev_output_channel = output_channel
         # out
-        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
+        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
         self.conv_act = nn.SiLU()
-        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
+        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
     def forward(
         self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
@@ -162,7 +132,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
-        t_emb = self.time_steps(timesteps)
+        t_emb = self.time_proj(timesteps)
         emb = self.time_embedding(t_emb)
         # 2. pre-process
@@ -171,7 +141,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
-        for downsample_block in self.downsample_blocks:
+        for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "skip_conv"):
                 sample, res_samples, skip_sample = downsample_block(
                     hidden_states=sample, temb=emb, skip_sample=skip_sample
@@ -182,11 +152,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
             down_block_res_samples += res_samples
         # 4. mid
-        sample = self.mid(sample, emb)
+        sample = self.mid_block(sample, emb)
         # 5. up
         skip_sample = None
-        for upsample_block in self.upsample_blocks:
+        for upsample_block in self.up_blocks:
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
...
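With the renamed constructor, the unconditional model is built and called as below. This is a sketch assuming a diffusers build from this commit; `sample_size=32` is illustrative, the block types and channel widths mirror the new defaults in the diff, and the `['sample']` output key matches the checking script above:

```python
import torch
from diffusers import UNet2DModel

# Sketch: instantiate the renamed model with the new argument names.
unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(224, 448, 672, 896),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
)
noise = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])
with torch.no_grad():
    sample = unet(noise, timestep)["sample"]
print(sample.shape)  # torch.Size([1, 3, 32, 32])
```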
@@ -9,142 +9,107 @@ from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
-class UNetConditionalModel(ModelMixin, ConfigMixin):
-    """
-    The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
-    model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param
-    num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample
-    rates at which
-        attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
-        downsampling, attention will be used.
-    :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
-    conv_resample: if True, use learned convolutions for upsampling and
-        downsampling.
-    :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
-    model will be
-        class-conditional with `num_classes` classes.
-    :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
-    heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
-        a fixed channel width per attention head.
-    :param num_heads_upsample: works with num_heads to set a different number
-        of heads for upsampling. Deprecated.
-    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks
-    for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially
-        increased efficiency.
-    """
+class UNet2DConditionModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        image_size=None,
+        sample_size=None,
         in_channels=4,
         out_channels=4,
-        num_res_blocks=2,
-        dropout=0,
-        block_channels=(320, 640, 1280, 1280),
-        down_blocks=(
-            "UNetResCrossAttnDownBlock2D",
-            "UNetResCrossAttnDownBlock2D",
-            "UNetResCrossAttnDownBlock2D",
-            "UNetResDownBlock2D",
-        ),
-        downsample_padding=1,
-        up_blocks=(
-            "UNetResUpBlock2D",
-            "UNetResCrossAttnUpBlock2D",
-            "UNetResCrossAttnUpBlock2D",
-            "UNetResCrossAttnUpBlock2D",
-        ),
-        resnet_act_fn="silu",
-        resnet_eps=1e-5,
-        conv_resample=True,
-        num_head_channels=8,
+        center_input_sample=False,
         flip_sin_to_cos=True,
-        downscale_freq_shift=0,
+        freq_shift=0,
+        down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
+        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+        block_out_channels=(320, 640, 1280, 1280),
+        layers_per_block=2,
+        downsample_padding=1,
         mid_block_scale_factor=1,
-        center_input_sample=False,
-        resnet_num_groups=30,
+        act_fn="silu",
+        norm_num_groups=32,
+        norm_eps=1e-5,
+        attention_head_dim=8,
     ):
         super().__init__()
-        self.image_size = image_size
-        time_embed_dim = block_channels[0] * 4
+        self.sample_size = sample_size
+        time_embed_dim = block_out_channels[0] * 4
         # input
-        self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
+        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
         # time
-        self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
-        timestep_input_dim = block_channels[0]
+        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+        timestep_input_dim = block_out_channels[0]
         self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
-        self.downsample_blocks = nn.ModuleList([])
-        self.mid = None
-        self.upsample_blocks = nn.ModuleList([])
+        self.down_blocks = nn.ModuleList([])
+        self.mid_block = None
+        self.up_blocks = nn.ModuleList([])
         # down
-        output_channel = block_channels[0]
-        for i, down_block_type in enumerate(down_blocks):
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
             input_channel = output_channel
-            output_channel = block_channels[i]
-            is_final_block = i == len(block_channels) - 1
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
             down_block = get_down_block(
                 down_block_type,
-                num_layers=num_res_blocks,
+                num_layers=layers_per_block,
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=time_embed_dim,
                 add_downsample=not is_final_block,
-                resnet_eps=resnet_eps,
-                resnet_act_fn=resnet_act_fn,
-                attn_num_head_channels=num_head_channels,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                attn_num_head_channels=attention_head_dim,
                 downsample_padding=downsample_padding,
             )
-            self.downsample_blocks.append(down_block)
+            self.down_blocks.append(down_block)
         # mid
-        self.mid = UNetMidBlock2DCrossAttn(
-            in_channels=block_channels[-1],
-            dropout=dropout,
+        self.mid_block = UNetMidBlock2DCrossAttn(
+            in_channels=block_out_channels[-1],
             temb_channels=time_embed_dim,
-            resnet_eps=resnet_eps,
-            resnet_act_fn=resnet_act_fn,
+            resnet_eps=norm_eps,
+            resnet_act_fn=act_fn,
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift="default",
-            attn_num_head_channels=num_head_channels,
-            resnet_groups=resnet_num_groups,
+            attn_num_head_channels=attention_head_dim,
+            resnet_groups=norm_num_groups,
         )
         # up
-        reversed_block_channels = list(reversed(block_channels))
-        output_channel = reversed_block_channels[0]
-        for i, up_block_type in enumerate(up_blocks):
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
             prev_output_channel = output_channel
-            output_channel = reversed_block_channels[i]
-            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]
-            is_final_block = i == len(block_channels) - 1
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+            is_final_block = i == len(block_out_channels) - 1
             up_block = get_up_block(
                 up_block_type,
-                num_layers=num_res_blocks + 1,
+                num_layers=layers_per_block + 1,
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 temb_channels=time_embed_dim,
                 add_upsample=not is_final_block,
-                resnet_eps=resnet_eps,
-                resnet_act_fn=resnet_act_fn,
-                attn_num_head_channels=num_head_channels,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                attn_num_head_channels=attention_head_dim,
             )
-            self.upsample_blocks.append(up_block)
+            self.up_blocks.append(up_block)
             prev_output_channel = output_channel
         # out
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=resnet_num_groups, eps=resnet_eps)
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
         self.conv_act = nn.SiLU()
-        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)
+        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
     def forward(
         self,
@@ -164,7 +129,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
-        t_emb = self.time_steps(timesteps)
+        t_emb = self.time_proj(timesteps)
         emb = self.time_embedding(t_emb)
         # 2. pre-process
@@ -172,7 +137,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
-        for downsample_block in self.downsample_blocks:
+        for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                 sample, res_samples = downsample_block(
@@ -184,10 +149,10 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
             down_block_res_samples += res_samples
         # 4. mid
-        sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)
+        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
         # 5. up
-        for upsample_block in self.upsample_blocks:
+        for upsample_block in self.up_blocks:
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
...
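The conditional variant follows the same scheme; construction with the renamed arguments is sketched below, again assuming a diffusers build from this commit and using the new defaults visible in the diff. The text conditioning (`encoder_hidden_states`) is supplied at call time to the cross-attention blocks, as the forward hunk above shows; its shape depends on the text encoder and is therefore not fabricated here:

```python
from diffusers import UNet2DConditionModel

# Sketch: the conditional UNet after the rename.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(320, 640, 1280, 1280),
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
)
print(sum(p.numel() for p in unet.parameters()))
```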
@@ -33,8 +33,9 @@ def get_down_block(
     attn_num_head_channels,
     downsample_padding=None,
 ):
-    if down_block_type == "UNetResDownBlock2D":
-        return UNetResDownBlock2D(
+    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+    if down_block_type == "DownBlock2D":
+        return DownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -44,8 +45,8 @@ def get_down_block(
             resnet_act_fn=resnet_act_fn,
             downsample_padding=downsample_padding,
         )
-    elif down_block_type == "UNetResAttnDownBlock2D":
-        return UNetResAttnDownBlock2D(
+    elif down_block_type == "AttnDownBlock2D":
+        return AttnDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -56,8 +57,8 @@ def get_down_block(
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
-    elif down_block_type == "UNetResCrossAttnDownBlock2D":
-        return UNetResCrossAttnDownBlock2D(
+    elif down_block_type == "CrossAttnDownBlock2D":
+        return CrossAttnDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -68,8 +69,8 @@ def get_down_block(
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
-    elif down_block_type == "UNetResSkipDownBlock2D":
-        return UNetResSkipDownBlock2D(
+    elif down_block_type == "SkipDownBlock2D":
+        return SkipDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -79,8 +80,8 @@ def get_down_block(
             resnet_act_fn=resnet_act_fn,
             downsample_padding=downsample_padding,
         )
-    elif down_block_type == "UNetResAttnSkipDownBlock2D":
-        return UNetResAttnSkipDownBlock2D(
+    elif down_block_type == "AttnSkipDownBlock2D":
+        return AttnSkipDownBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -105,8 +106,9 @@ def get_up_block(
     resnet_act_fn,
     attn_num_head_channels,
 ):
-    if up_block_type == "UNetResUpBlock2D":
-        return UNetResUpBlock2D(
+    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+    if up_block_type == "UpBlock2D":
+        return UpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -116,8 +118,8 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
         )
-    elif up_block_type == "UNetResCrossAttnUpBlock2D":
-        return UNetResCrossAttnUpBlock2D(
+    elif up_block_type == "CrossAttnUpBlock2D":
+        return CrossAttnUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -128,8 +130,8 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             attn_num_head_channels=attn_num_head_channels,
         )
-    elif up_block_type == "UNetResAttnUpBlock2D":
-        return UNetResAttnUpBlock2D(
+    elif up_block_type == "AttnUpBlock2D":
+        return AttnUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -140,8 +142,8 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             attn_num_head_channels=attn_num_head_channels,
         )
-    elif up_block_type == "UNetResSkipUpBlock2D":
-        return UNetResSkipUpBlock2D(
+    elif up_block_type == "SkipUpBlock2D":
+        return SkipUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -151,8 +153,8 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
         )
-    elif up_block_type == "UNetResAttnSkipUpBlock2D":
-        return UNetResAttnSkipUpBlock2D(
+    elif up_block_type == "AttnSkipUpBlock2D":
+        return AttnSkipUpBlock2D(
             num_layers=num_layers,
             in_channels=in_channels,
             out_channels=out_channels,
@@ -322,7 +324,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         return hidden_states
-class UNetResAttnDownBlock2D(nn.Module):
+class AttnDownBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -403,7 +405,7 @@ class UNetResAttnDownBlock2D(nn.Module):
         return hidden_states, output_states
-class UNetResCrossAttnDownBlock2D(nn.Module):
+class CrossAttnDownBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -485,7 +487,7 @@ class UNetResCrossAttnDownBlock2D(nn.Module):
         return hidden_states, output_states
-class UNetResDownBlock2D(nn.Module):
+class DownBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -551,7 +553,7 @@ class UNetResDownBlock2D(nn.Module):
         return hidden_states, output_states
-class UNetResAttnSkipDownBlock2D(nn.Module):
+class AttnSkipDownBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -644,7 +646,7 @@ class UNetResAttnSkipDownBlock2D(nn.Module):
         return hidden_states, output_states, skip_sample
-class UNetResSkipDownBlock2D(nn.Module):
+class SkipDownBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -723,7 +725,7 @@ class UNetResSkipDownBlock2D(nn.Module):
         return hidden_states, output_states, skip_sample
-class UNetResAttnUpBlock2D(nn.Module):
+class AttnUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -801,7 +803,7 @@ class UNetResAttnUpBlock2D(nn.Module):
         return hidden_states
-class UNetResCrossAttnUpBlock2D(nn.Module):
+class CrossAttnUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -881,7 +883,7 @@ class UNetResCrossAttnUpBlock2D(nn.Module):
         return hidden_states
-class UNetResUpBlock2D(nn.Module):
+class UpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -944,7 +946,7 @@ class UNetResUpBlock2D(nn.Module):
         return hidden_states
-class UNetResAttnSkipUpBlock2D(nn.Module):
+class AttnSkipUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -1055,7 +1057,7 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
         return hidden_states, skip_sample
-class UNetResSkipUpBlock2D(nn.Module):
+class SkipUpBlock2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
...
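The one-line shim added to `get_down_block` and `get_up_block` keeps old configs working: legacy "UNetRes"-prefixed block-type names are mapped to the new names by stripping the 7-character prefix. A quick standalone check of that mapping:

```python
# Sketch of the backwards-compatibility shim from the hunk above.
def strip_legacy_prefix(block_type: str) -> str:
    return block_type[7:] if block_type.startswith("UNetRes") else block_type

assert len("UNetRes") == 7
assert strip_legacy_prefix("UNetResAttnDownBlock2D") == "AttnDownBlock2D"
assert strip_legacy_prefix("DownBlock2D") == "DownBlock2D"
```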
@@ -25,7 +25,7 @@ from .configuration_utils import ConfigMixin
 from .utils import DIFFUSERS_CACHE, logging
-INDEX_FILE = "diffusion_model.pt"
+INDEX_FILE = "diffusion_pytorch_model.bin"
 logger = logging.get_logger(__name__)
...
@@ -28,7 +28,9 @@ class DDIMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
+    def __call__(
+        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
+    ):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -37,7 +39,7 @@ class DDIMPipeline(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         image = image.to(torch_device)
...
@@ -36,7 +36,7 @@ class DDPMPipeline(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         image = image.to(torch_device)
...
@@ -52,7 +52,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
         text_embeddings = self.bert(text_input.input_ids.to(torch_device))
         latents = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         latents = latents.to(torch_device)
...
@@ -24,7 +24,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
         self.vqvae.to(torch_device)
         latents = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
         latents = latents.to(torch_device)
...
@@ -38,7 +38,7 @@ class PNDMPipeline(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
        )
         image = image.to(torch_device)
...
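Every pipeline now derives its initial noise shape from the unet's `sample_size` (formerly `image_size`). A tiny sketch of that shape computation with illustrative values:

```python
import torch

# Sketch: initial noise shape derived from sample_size (was image_size).
batch_size, in_channels, sample_size = 1, 3, 256
image = torch.randn(batch_size, in_channels, sample_size, sample_size)
assert image.shape == (1, 3, 256, 256)
```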