Unverified Commit cb4b3f0b authored by Sayak Paul, committed by GitHub

[OmegaConf] replace it with `yaml` (#6488)

* remove omegaconf from convert_from_ckpt.

* remove from single_file.

* change to string-based subscription.

* style

* okay

* fix: vae_param

* no . indexing.

* style

* style

* turn getattrs into explicit if/else

* style

* propagate changes to ldm_uncond.

* propagate to gligen

* propagate to if.

* fix: quotes.

* propagate to audioldm.

* propagate to audioldm2

* propagate to musicldm.

* propagate to vq_diffusion

* propagate to zero123.

* remove omegaconf from diffusers codebase.
parent 3d574b3b
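
The change is mechanically the same in every script touched by this commit: configs that were previously loaded with OmegaConf, which allows attribute-style access such as config.model.params.timesteps, are now plain nested dicts returned by yaml.safe_load, so every dotted lookup becomes a string subscription. A minimal sketch of the pattern, using a made-up config snippet rather than any real checkpoint config:

import yaml

# Hypothetical LDM-style config fragment, for illustration only.
config_text = """
model:
  params:
    timesteps: 1000
    linear_start: 0.0015
    linear_end: 0.0195
"""

config = yaml.safe_load(config_text)

# Before: OmegaConf.load(...) returned a DictConfig supporting config.model.params.timesteps.
# After: yaml.safe_load returns nested dicts, so the same value is read with string keys.
num_train_timesteps = config["model"]["params"]["timesteps"]
assert num_train_timesteps == 1000
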
 import argparse
-import OmegaConf
 import torch
+import yaml
 from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
 def convert_ldm_original(checkpoint_path, config_path, output_path):
-    config = OmegaConf.load(config_path)
+    config = yaml.safe_load(config_path)
     state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
     keys = list(state_dict.keys())
@@ -25,8 +25,8 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
-    vqvae_init_args = config.model.params.first_stage_config.params
-    unet_init_args = config.model.params.unet_config.params
+    vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
+    unet_init_args = config["model"]["params"]["unet_config"]["params"]
     vqvae = VQModel(**vqvae_init_args).eval()
     vqvae.load_state_dict(first_stage_dict)
@@ -35,10 +35,10 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
     unet.load_state_dict(unet_state_dict)
     noise_scheduler = DDIMScheduler(
-        timesteps=config.model.params.timesteps,
+        timesteps=config["model"]["params"]["timesteps"],
         beta_schedule="scaled_linear",
-        beta_start=config.model.params.linear_start,
-        beta_end=config.model.params.linear_end,
+        beta_start=config["model"]["params"]["linear_start"],
+        beta_end=config["model"]["params"]["linear_end"],
         clip_sample=False,
     )
...
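
One practical difference worth keeping in mind: OmegaConf.load accepted a filesystem path, whereas yaml.safe_load parses a YAML string or an already opened stream, so callers typically open the config file before parsing it. A hedged sketch of a small helper (the name load_yaml_config is made up for illustration and is not part of this diff):

import yaml

def load_yaml_config(config_path):
    # yaml.safe_load expects YAML text or a file object, not a path,
    # so open the file before parsing it.
    with open(config_path, "r") as f:
        return yaml.safe_load(f)
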
@@ -2,6 +2,7 @@ import argparse
 import re
 import torch
+import yaml
 from transformers import (
     CLIPProcessor,
     CLIPTextModel,
@@ -28,8 +29,6 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
     textenc_conversion_map,
     textenc_pattern,
 )
-from diffusers.utils import is_omegaconf_available
-from diffusers.utils.import_utils import BACKENDS_MAPPING
 def convert_open_clip_checkpoint(checkpoint):
@@ -370,52 +369,52 @@ def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=Fa
 def create_vae_config(original_config, image_size: int):
-    vae_params = original_config.autoencoder.params.ddconfig
-    _ = original_config.autoencoder.params.embed_dim
-    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
+    vae_params = original_config["autoencoder"]["params"]["ddconfig"]
+    _ = original_config["autoencoder"]["params"]["embed_dim"]
+    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
     down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
     up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
     config = {
         "sample_size": image_size,
-        "in_channels": vae_params.in_channels,
-        "out_channels": vae_params.out_ch,
+        "in_channels": vae_params["in_channels"],
+        "out_channels": vae_params["out_ch"],
         "down_block_types": tuple(down_block_types),
         "up_block_types": tuple(up_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "latent_channels": vae_params.z_channels,
-        "layers_per_block": vae_params.num_res_blocks,
+        "latent_channels": vae_params["z_channels"],
+        "layers_per_block": vae_params["num_res_blocks"],
     }
     return config
 def create_unet_config(original_config, image_size: int, attention_type):
-    unet_params = original_config.model.params
-    vae_params = original_config.autoencoder.params.ddconfig
-    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+    unet_params = original_config["model"]["params"]
+    vae_params = original_config["autoencoder"]["params"]["ddconfig"]
+    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
     down_block_types = []
     resolution = 1
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
+        block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
         down_block_types.append(block_type)
         if i != len(block_out_channels) - 1:
            resolution *= 2
     up_block_types = []
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
+        block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
         up_block_types.append(block_type)
         resolution //= 2
-    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
-    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
+    head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
     use_linear_projection = (
-        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+        unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
     )
     if use_linear_projection:
         if head_dim is None:
@@ -423,11 +422,11 @@ def create_unet_config(original_config, image_size: int, attention_type):
     config = {
         "sample_size": image_size // vae_scale_factor,
-        "in_channels": unet_params.in_channels,
+        "in_channels": unet_params["in_channels"],
         "down_block_types": tuple(down_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "layers_per_block": unet_params.num_res_blocks,
-        "cross_attention_dim": unet_params.context_dim,
+        "layers_per_block": unet_params["num_res_blocks"],
+        "cross_attention_dim": unet_params["context_dim"],
         "attention_head_dim": head_dim,
         "use_linear_projection": use_linear_projection,
         "attention_type": attention_type,
@@ -445,11 +444,6 @@ def convert_gligen_to_diffusers(
     num_in_channels: int = None,
     device: str = None,
 ):
-    if not is_omegaconf_available():
-        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
-    from omegaconf import OmegaConf
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -461,14 +455,14 @@ def convert_gligen_to_diffusers(
     else:
         print("global_step key not found in model")
-    original_config = OmegaConf.load(original_config_file)
+    original_config = yaml.safe_load(original_config_file)
     if num_in_channels is not None:
         original_config["model"]["params"]["in_channels"] = num_in_channels
-    num_train_timesteps = original_config.diffusion.params.timesteps
-    beta_start = original_config.diffusion.params.linear_start
-    beta_end = original_config.diffusion.params.linear_end
+    num_train_timesteps = original_config["diffusion"]["params"]["timesteps"]
+    beta_start = original_config["diffusion"]["params"]["linear_start"]
+    beta_end = original_config["diffusion"]["params"]["linear_end"]
     scheduler = DDIMScheduler(
         beta_end=beta_end,
...
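
Because a plain dict has no attribute fallback, the optional keys above are read behind explicit membership checks, which is presumably the "turn getattrs into explicit if/else" item in the commit message. An equivalent and slightly more compact idiom, shown only as an alternative sketch with a dummy config dict, is dict.get:

# Alternative sketch only, with a dummy config dict; equivalent to the explicit membership checks above.
unet_params = {"num_heads": 8}
head_dim = unet_params.get("num_heads")  # 8 here; None when the key is absent
use_linear_projection = unet_params.get("use_linear_in_transformer", False)
assert (head_dim, use_linear_projection) == (8, False)
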
@@ -4,6 +4,7 @@ import os
 import numpy as np
 import torch
+import yaml
 from torch.nn import functional as F
 from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer
@@ -11,14 +12,6 @@ from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet
 from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker
-try:
-    from omegaconf import OmegaConf
-except ImportError:
-    raise ImportError(
-        "OmegaConf is required to convert the IF checkpoints. Please install it with `pip install" " OmegaConf`."
-    )
 def parse_args():
     parser = argparse.ArgumentParser()
@@ -143,8 +136,8 @@ def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safet
 def get_stage_1_unet(unet_config, unet_checkpoint_path):
-    original_unet_config = OmegaConf.load(unet_config)
-    original_unet_config = original_unet_config.params
+    original_unet_config = yaml.safe_load(unet_config)
+    original_unet_config = original_unet_config["params"]
     unet_diffusers_config = create_unet_diffusers_config(original_unet_config)
@@ -215,11 +208,11 @@ def convert_safety_checker(p_head_path, w_head_path):
 def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
-    attention_resolutions = parse_list(original_unet_config.attention_resolutions)
-    attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions]
-    channel_mult = parse_list(original_unet_config.channel_mult)
-    block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult]
+    attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
+    attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
+    channel_mult = parse_list(original_unet_config["channel_mult"])
+    block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
     down_block_types = []
     resolution = 1
@@ -227,7 +220,7 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
     for i in range(len(block_out_channels)):
         if resolution in attention_resolutions:
             block_type = "SimpleCrossAttnDownBlock2D"
-        elif original_unet_config.resblock_updown:
+        elif original_unet_config["resblock_updown"]:
             block_type = "ResnetDownsampleBlock2D"
         else:
             block_type = "DownBlock2D"
@@ -241,17 +234,17 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
     for i in range(len(block_out_channels)):
         if resolution in attention_resolutions:
             block_type = "SimpleCrossAttnUpBlock2D"
-        elif original_unet_config.resblock_updown:
+        elif original_unet_config["resblock_updown"]:
             block_type = "ResnetUpsampleBlock2D"
         else:
             block_type = "UpBlock2D"
         up_block_types.append(block_type)
         resolution //= 2
-    head_dim = original_unet_config.num_head_channels
+    head_dim = original_unet_config["num_head_channels"]
     use_linear_projection = (
-        original_unet_config.use_linear_in_transformer
+        original_unet_config["use_linear_in_transformer"]
         if "use_linear_in_transformer" in original_unet_config
         else False
     )
@@ -264,27 +257,27 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
     if class_embed_type is None:
         if "num_classes" in original_unet_config:
-            if original_unet_config.num_classes == "sequential":
+            if original_unet_config["num_classes"] == "sequential":
                 class_embed_type = "projection"
                 assert "adm_in_channels" in original_unet_config
-                projection_class_embeddings_input_dim = original_unet_config.adm_in_channels
+                projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
             else:
                 raise NotImplementedError(
-                    f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}"
+                    f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
                 )
     config = {
-        "sample_size": original_unet_config.image_size,
-        "in_channels": original_unet_config.in_channels,
+        "sample_size": original_unet_config["image_size"],
+        "in_channels": original_unet_config["in_channels"],
         "down_block_types": tuple(down_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "layers_per_block": original_unet_config.num_res_blocks,
-        "cross_attention_dim": original_unet_config.encoder_channels,
+        "layers_per_block": original_unet_config["num_res_blocks"],
+        "cross_attention_dim": original_unet_config["encoder_channels"],
         "attention_head_dim": head_dim,
         "use_linear_projection": use_linear_projection,
         "class_embed_type": class_embed_type,
         "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
-        "out_channels": original_unet_config.out_channels,
+        "out_channels": original_unet_config["out_channels"],
         "up_block_types": tuple(up_block_types),
         "upcast_attention": False,  # TODO: guessing
         "cross_attention_norm": "group_norm",
@@ -293,11 +286,11 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
         "act_fn": "gelu",
     }
-    if original_unet_config.use_scale_shift_norm:
+    if original_unet_config["use_scale_shift_norm"]:
         config["resnet_time_scale_shift"] = "scale_shift"
     if "encoder_dim" in original_unet_config:
-        config["encoder_hid_dim"] = original_unet_config.encoder_dim
+        config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
     return config
@@ -725,15 +718,15 @@ def parse_list(value):
 def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
     orig_path = unet_checkpoint_path
-    original_unet_config = OmegaConf.load(os.path.join(orig_path, "config.yml"))
-    original_unet_config = original_unet_config.params
+    original_unet_config = yaml.safe_load(os.path.join(orig_path, "config.yml"))
+    original_unet_config = original_unet_config["params"]
     unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
-    unet_diffusers_config["time_embedding_dim"] = original_unet_config.model_channels * int(
-        original_unet_config.channel_mult.split(",")[-1]
+    unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int(
+        original_unet_config["channel_mult"].split(",")[-1]
     )
-    if original_unet_config.encoder_dim != original_unet_config.encoder_channels:
-        unet_diffusers_config["encoder_hid_dim"] = original_unet_config.encoder_dim
+    if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
+        unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
     unet_diffusers_config["class_embed_type"] = "timestep"
     unet_diffusers_config["addition_embed_type"] = "text"
@@ -742,16 +735,16 @@ def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_siz
     unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
     unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
     unet_diffusers_config["only_cross_attention"] = (
-        bool(original_unet_config.disable_self_attentions)
+        bool(original_unet_config["disable_self_attentions"])
         if (
             "disable_self_attentions" in original_unet_config
-            and isinstance(original_unet_config.disable_self_attentions, int)
+            and isinstance(original_unet_config["disable_self_attentions"], int)
         )
         else True
     )
     if sample_size is None:
-        unet_diffusers_config["sample_size"] = original_unet_config.image_size
+        unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
     else:
         # The second upscaler unet's sample size is incorrectly specified
         # in the config and is instead hardcoded in source
@@ -783,11 +776,11 @@ def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_siz
 def superres_create_unet_diffusers_config(original_unet_config):
-    attention_resolutions = parse_list(original_unet_config.attention_resolutions)
-    attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions]
-    channel_mult = parse_list(original_unet_config.channel_mult)
-    block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult]
+    attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
+    attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
+    channel_mult = parse_list(original_unet_config["channel_mult"])
+    block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
     down_block_types = []
     resolution = 1
@@ -795,7 +788,7 @@ def superres_create_unet_diffusers_config(original_unet_config):
     for i in range(len(block_out_channels)):
         if resolution in attention_resolutions:
             block_type = "SimpleCrossAttnDownBlock2D"
-        elif original_unet_config.resblock_updown:
+        elif original_unet_config["resblock_updown"]:
             block_type = "ResnetDownsampleBlock2D"
         else:
             block_type = "DownBlock2D"
@@ -809,16 +802,16 @@ def superres_create_unet_diffusers_config(original_unet_config):
     for i in range(len(block_out_channels)):
         if resolution in attention_resolutions:
             block_type = "SimpleCrossAttnUpBlock2D"
-        elif original_unet_config.resblock_updown:
+        elif original_unet_config["resblock_updown"]:
             block_type = "ResnetUpsampleBlock2D"
         else:
             block_type = "UpBlock2D"
         up_block_types.append(block_type)
         resolution //= 2
-    head_dim = original_unet_config.num_head_channels
+    head_dim = original_unet_config["num_head_channels"]
     use_linear_projection = (
-        original_unet_config.use_linear_in_transformer
+        original_unet_config["use_linear_in_transformer"]
         if "use_linear_in_transformer" in original_unet_config
         else False
     )
@@ -831,26 +824,26 @@ def superres_create_unet_diffusers_config(original_unet_config):
     projection_class_embeddings_input_dim = None
     if "num_classes" in original_unet_config:
-        if original_unet_config.num_classes == "sequential":
+        if original_unet_config["num_classes"] == "sequential":
             class_embed_type = "projection"
             assert "adm_in_channels" in original_unet_config
-            projection_class_embeddings_input_dim = original_unet_config.adm_in_channels
+            projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
         else:
             raise NotImplementedError(
-                f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}"
+                f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
             )
     config = {
-        "in_channels": original_unet_config.in_channels,
+        "in_channels": original_unet_config["in_channels"],
         "down_block_types": tuple(down_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "layers_per_block": tuple(original_unet_config.num_res_blocks),
-        "cross_attention_dim": original_unet_config.encoder_channels,
+        "layers_per_block": tuple(original_unet_config["num_res_blocks"]),
+        "cross_attention_dim": original_unet_config["encoder_channels"],
         "attention_head_dim": head_dim,
         "use_linear_projection": use_linear_projection,
         "class_embed_type": class_embed_type,
         "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
-        "out_channels": original_unet_config.out_channels,
+        "out_channels": original_unet_config["out_channels"],
         "up_block_types": tuple(up_block_types),
         "upcast_attention": False,  # TODO: guessing
         "cross_attention_norm": "group_norm",
@@ -858,7 +851,7 @@ def superres_create_unet_diffusers_config(original_unet_config):
         "act_fn": "gelu",
     }
-    if original_unet_config.use_scale_shift_norm:
+    if original_unet_config["use_scale_shift_norm"]:
         config["resnet_time_scale_shift"] = "scale_shift"
     return config
...
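
In the IF configs, channel_mult and attention_resolutions are stored as comma-separated strings rather than YAML lists, which is why the script runs them through parse_list and, for the time embedding, splits on commas directly. A tiny worked example with made-up values (not taken from any real IF config):

# Illustrative config values only.
original_unet_config = {"model_channels": 192, "channel_mult": "1,2,4,8"}
channel_mult = [int(m) for m in original_unet_config["channel_mult"].split(",")]
block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
time_embedding_dim = original_unet_config["model_channels"] * int(original_unet_config["channel_mult"].split(",")[-1])
assert block_out_channels == [192, 384, 768, 1536] and time_embedding_dim == 1536
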
@@ -19,6 +19,7 @@ import re
 from typing import List, Union
 import torch
+import yaml
 from transformers import (
     AutoFeatureExtractor,
     AutoTokenizer,
@@ -45,7 +46,7 @@ from diffusers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.utils import is_omegaconf_available, is_safetensors_available
+from diffusers.utils import is_safetensors_available
 from diffusers.utils.import_utils import BACKENDS_MAPPING
@@ -212,41 +213,41 @@ def create_unet_diffusers_config(original_config, image_size: int):
     """
     Creates a UNet config for diffusers based on the config of the original AudioLDM2 model.
     """
-    unet_params = original_config.model.params.unet_config.params
-    vae_params = original_config.model.params.first_stage_config.params.ddconfig
-    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+    unet_params = original_config["model"]["params"]["unet_config"]["params"]
+    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
+    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
     down_block_types = []
     resolution = 1
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
+        block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
         down_block_types.append(block_type)
         if i != len(block_out_channels) - 1:
             resolution *= 2
     up_block_types = []
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
+        block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
         up_block_types.append(block_type)
         resolution //= 2
-    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
-    cross_attention_dim = list(unet_params.context_dim) if "context_dim" in unet_params else block_out_channels
+    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
+    cross_attention_dim = list(unet_params["context_dim"]) if "context_dim" in unet_params else block_out_channels
     if len(cross_attention_dim) > 1:
         # require two or more cross-attention layers per-block, each of different dimension
         cross_attention_dim = [cross_attention_dim for _ in range(len(block_out_channels))]
     config = {
         "sample_size": image_size // vae_scale_factor,
-        "in_channels": unet_params.in_channels,
-        "out_channels": unet_params.out_channels,
+        "in_channels": unet_params["in_channels"],
+        "out_channels": unet_params["out_channels"],
         "down_block_types": tuple(down_block_types),
         "up_block_types": tuple(up_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "layers_per_block": unet_params.num_res_blocks,
-        "transformer_layers_per_block": unet_params.transformer_depth,
+        "layers_per_block": unet_params["num_res_blocks"],
+        "transformer_layers_per_block": unet_params["transformer_depth"],
         "cross_attention_dim": tuple(cross_attention_dim),
     }
@@ -259,24 +260,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
     Creates a VAE config for diffusers based on the config of the original AudioLDM2 model. Compared to the original
     Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
     """
-    vae_params = original_config.model.params.first_stage_config.params.ddconfig
-    _ = original_config.model.params.first_stage_config.params.embed_dim
-    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
+    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
+    _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
+    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
     down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
     up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-    scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
+    scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
     config = {
         "sample_size": image_size,
-        "in_channels": vae_params.in_channels,
-        "out_channels": vae_params.out_ch,
+        "in_channels": vae_params["in_channels"],
+        "out_channels": vae_params["out_ch"],
         "down_block_types": tuple(down_block_types),
         "up_block_types": tuple(up_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "latent_channels": vae_params.z_channels,
-        "layers_per_block": vae_params.num_res_blocks,
+        "latent_channels": vae_params["z_channels"],
+        "layers_per_block": vae_params["num_res_blocks"],
         "scaling_factor": float(scaling_factor),
     }
     return config
@@ -285,9 +286,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
 # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
 def create_diffusers_schedular(original_config):
     schedular = DDIMScheduler(
-        num_train_timesteps=original_config.model.params.timesteps,
-        beta_start=original_config.model.params.linear_start,
-        beta_end=original_config.model.params.linear_end,
+        num_train_timesteps=original_config["model"]["params"]["timesteps"],
+        beta_start=original_config["model"]["params"]["linear_start"],
+        beta_end=original_config["model"]["params"]["linear_end"],
         beta_schedule="scaled_linear",
     )
     return schedular
@@ -692,17 +693,17 @@ def create_transformers_vocoder_config(original_config):
     """
     Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
     """
-    vocoder_params = original_config.model.params.vocoder_config.params
+    vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
     config = {
-        "model_in_dim": vocoder_params.num_mels,
-        "sampling_rate": vocoder_params.sampling_rate,
-        "upsample_initial_channel": vocoder_params.upsample_initial_channel,
-        "upsample_rates": list(vocoder_params.upsample_rates),
-        "upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
-        "resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
+        "model_in_dim": vocoder_params["num_mels"],
+        "sampling_rate": vocoder_params["sampling_rate"],
+        "upsample_initial_channel": vocoder_params["upsample_initial_channel"],
+        "upsample_rates": list(vocoder_params["upsample_rates"]),
+        "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
+        "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
         "resblock_dilation_sizes": [
-            list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
+            list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
         ],
         "normalize_before": False,
     }
@@ -876,11 +877,6 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
     return: An AudioLDM2Pipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
-    if not is_omegaconf_available():
-        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
-    from omegaconf import OmegaConf
     if from_safetensors:
         if not is_safetensors_available():
             raise ValueError(BACKENDS_MAPPING["safetensors"][1])
@@ -903,9 +899,8 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
     if original_config_file is None:
         original_config = DEFAULT_CONFIG
-        original_config = OmegaConf.create(original_config)
     else:
-        original_config = OmegaConf.load(original_config_file)
+        original_config = yaml.safe_load(original_config_file)
     if image_size is not None:
         original_config["model"]["params"]["unet_config"]["params"]["image_size"] = image_size
@@ -926,9 +921,9 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
     if prediction_type is None:
         prediction_type = "epsilon"
-    num_train_timesteps = original_config.model.params.timesteps
-    beta_start = original_config.model.params.linear_start
-    beta_end = original_config.model.params.linear_end
+    num_train_timesteps = original_config["model"]["params"]["timesteps"]
+    beta_start = original_config["model"]["params"]["linear_start"]
+    beta_end = original_config["model"]["params"]["linear_end"]
     scheduler = DDIMScheduler(
         beta_end=beta_end,
@@ -1026,9 +1021,9 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
     # Convert the GPT2 encoder model: AudioLDM2 uses the same configuration as the original GPT2 base model
     gpt2_config = GPT2Config.from_pretrained("gpt2")
     gpt2_model = GPT2Model(gpt2_config)
-    gpt2_model.config.max_new_tokens = (
-        original_config.model.params.cond_stage_config.crossattn_audiomae_generated.params.sequence_gen_length
-    )
+    gpt2_model.config.max_new_tokens = original_config["model"]["params"]["cond_stage_config"][
+        "crossattn_audiomae_generated"
+    ]["params"]["sequence_gen_length"]
     converted_gpt2_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.model.")
     gpt2_model.load_state_dict(converted_gpt2_checkpoint)
...
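
Across these converters the UNet sample_size is derived from the VAE downsampling factor, which in turn comes from the number of entries in ch_mult (each additional multiplier adds one downsampling stage). A worked example with illustrative numbers:

# Illustrative values only.
vae_params = {"ch_mult": [1, 2, 4, 4]}
image_size = 1024
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)  # three downsampling stages -> 8
sample_size = image_size // vae_scale_factor
assert (vae_scale_factor, sample_size) == (8, 128)
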
@@ -18,6 +18,7 @@ import argparse
 import re
 import torch
+import yaml
 from transformers import (
     AutoTokenizer,
     ClapTextConfig,
@@ -38,8 +39,6 @@ from diffusers import (
     PNDMScheduler,
     UNet2DConditionModel,
 )
-from diffusers.utils import is_omegaconf_available
-from diffusers.utils.import_utils import BACKENDS_MAPPING
 # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
@@ -215,45 +214,45 @@ def create_unet_diffusers_config(original_config, image_size: int):
     """
     Creates a UNet config for diffusers based on the config of the original AudioLDM model.
     """
-    unet_params = original_config.model.params.unet_config.params
-    vae_params = original_config.model.params.first_stage_config.params.ddconfig
-    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+    unet_params = original_config["model"]["params"]["unet_config"]["params"]
+    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
+    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
     down_block_types = []
     resolution = 1
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
+        block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
         down_block_types.append(block_type)
         if i != len(block_out_channels) - 1:
             resolution *= 2
     up_block_types = []
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
+        block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
         up_block_types.append(block_type)
         resolution //= 2
-    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
+    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
     cross_attention_dim = (
-        unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels
+        unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels
     )
     class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
     projection_class_embeddings_input_dim = (
-        unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None
+        unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None
     )
-    class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None
+    class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None
     config = {
         "sample_size": image_size // vae_scale_factor,
-        "in_channels": unet_params.in_channels,
-        "out_channels": unet_params.out_channels,
+        "in_channels": unet_params["in_channels"],
+        "out_channels": unet_params["out_channels"],
         "down_block_types": tuple(down_block_types),
         "up_block_types": tuple(up_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "layers_per_block": unet_params.num_res_blocks,
+        "layers_per_block": unet_params["num_res_blocks"],
         "cross_attention_dim": cross_attention_dim,
         "class_embed_type": class_embed_type,
         "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
@@ -269,24 +268,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
     Creates a VAE config for diffusers based on the config of the original AudioLDM model. Compared to the original
     Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
     """
-    vae_params = original_config.model.params.first_stage_config.params.ddconfig
-    _ = original_config.model.params.first_stage_config.params.embed_dim
-    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
+    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
+    _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
+    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
     down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
     up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-    scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
+    scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
     config = {
         "sample_size": image_size,
-        "in_channels": vae_params.in_channels,
-        "out_channels": vae_params.out_ch,
+        "in_channels": vae_params["in_channels"],
+        "out_channels": vae_params["out_ch"],
         "down_block_types": tuple(down_block_types),
         "up_block_types": tuple(up_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "latent_channels": vae_params.z_channels,
-        "layers_per_block": vae_params.num_res_blocks,
+        "latent_channels": vae_params["z_channels"],
+        "layers_per_block": vae_params["num_res_blocks"],
         "scaling_factor": float(scaling_factor),
     }
     return config
@@ -295,9 +294,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
 # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
 def create_diffusers_schedular(original_config):
     schedular = DDIMScheduler(
-        num_train_timesteps=original_config.model.params.timesteps,
-        beta_start=original_config.model.params.linear_start,
-        beta_end=original_config.model.params.linear_end,
+        num_train_timesteps=original_config["model"]["params"]["timesteps"],
+        beta_start=original_config["model"]["params"]["linear_start"],
+        beta_end=original_config["model"]["params"]["linear_end"],
         beta_schedule="scaled_linear",
     )
     return schedular
@@ -668,17 +667,17 @@ def create_transformers_vocoder_config(original_config):
     """
     Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
     """
-    vocoder_params = original_config.model.params.vocoder_config.params
+    vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
     config = {
-        "model_in_dim": vocoder_params.num_mels,
-        "sampling_rate": vocoder_params.sampling_rate,
-        "upsample_initial_channel": vocoder_params.upsample_initial_channel,
-        "upsample_rates": list(vocoder_params.upsample_rates),
-        "upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
-        "resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
+        "model_in_dim": vocoder_params["num_mels"],
+        "sampling_rate": vocoder_params["sampling_rate"],
+        "upsample_initial_channel": vocoder_params["upsample_initial_channel"],
+        "upsample_rates": list(vocoder_params["upsample_rates"]),
+        "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
+        "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
         "resblock_dilation_sizes": [
-            list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
+            list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
         ],
         "normalize_before": False,
     }
@@ -818,11 +817,6 @@ def load_pipeline_from_original_audioldm_ckpt(
     return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
-    if not is_omegaconf_available():
-        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
-    from omegaconf import OmegaConf
     if from_safetensors:
         from safetensors import safe_open
@@ -842,9 +836,8 @@ def load_pipeline_from_original_audioldm_ckpt(
     if original_config_file is None:
         original_config = DEFAULT_CONFIG
-        original_config = OmegaConf.create(original_config)
     else:
-        original_config = OmegaConf.load(original_config_file)
+        original_config = yaml.safe_load(original_config_file)
     if num_in_channels is not None:
         original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
@@ -868,9 +861,9 @@ def load_pipeline_from_original_audioldm_ckpt(
     if image_size is None:
         image_size = 512
-    num_train_timesteps = original_config.model.params.timesteps
-    beta_start = original_config.model.params.linear_start
-    beta_end = original_config.model.params.linear_end
+    num_train_timesteps = original_config["model"]["params"]["timesteps"]
+    beta_start = original_config["model"]["params"]["linear_start"]
+    beta_end = original_config["model"]["params"]["linear_end"]
     scheduler = DDIMScheduler(
         beta_end=beta_end,
...
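
The AudioLDM-family converters (including the AudioLDM2 and MusicLDM ones) only take the VAE scaling factor from the checkpoint when the original config has a scale_by_std entry; otherwise they fall back to the Stable Diffusion default of 0.18215. A hedged sketch of that selection with dummy inputs:

# Dummy inputs for illustration only.
original_config = {"model": {"params": {"scale_by_std": True, "timesteps": 1000}}}
checkpoint = {"scale_factor": 0.9227}
scaling_factor = (
    checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
)
assert scaling_factor == 0.9227
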
...@@ -18,6 +18,7 @@ import argparse ...@@ -18,6 +18,7 @@ import argparse
import re import re
import torch import torch
import yaml
from transformers import ( from transformers import (
AutoFeatureExtractor, AutoFeatureExtractor,
AutoTokenizer, AutoTokenizer,
...@@ -39,8 +40,6 @@ from diffusers import ( ...@@ -39,8 +40,6 @@ from diffusers import (
PNDMScheduler, PNDMScheduler,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils import is_omegaconf_available
from diffusers.utils.import_utils import BACKENDS_MAPPING
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
...@@ -212,45 +211,45 @@ def create_unet_diffusers_config(original_config, image_size: int): ...@@ -212,45 +211,45 @@ def create_unet_diffusers_config(original_config, image_size: int):
""" """
Creates a UNet config for diffusers based on the config of the original MusicLDM model. Creates a UNet config for diffusers based on the config of the original MusicLDM model.
""" """
unet_params = original_config.model.params.unet_config.params unet_params = original_config["model"]["params"]["unet_config"]["params"]
vae_params = original_config.model.params.first_stage_config.params.ddconfig vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = [] down_block_types = []
resolution = 1 resolution = 1
for i in range(len(block_out_channels)): for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type) down_block_types.append(block_type)
if i != len(block_out_channels) - 1: if i != len(block_out_channels) - 1:
resolution *= 2 resolution *= 2
up_block_types = [] up_block_types = []
for i in range(len(block_out_channels)): for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type) up_block_types.append(block_type)
resolution //= 2 resolution //= 2
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
cross_attention_dim = ( cross_attention_dim = (
unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels
) )
class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
projection_class_embeddings_input_dim = ( projection_class_embeddings_input_dim = (
unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None
) )
class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None
config = { config = {
"sample_size": image_size // vae_scale_factor, "sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels, "in_channels": unet_params["in_channels"],
"out_channels": unet_params.out_channels, "out_channels": unet_params["out_channels"],
"down_block_types": tuple(down_block_types), "down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types), "up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels), "block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks, "layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": cross_attention_dim, "cross_attention_dim": cross_attention_dim,
"class_embed_type": class_embed_type, "class_embed_type": class_embed_type,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
...@@ -266,24 +265,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int): ...@@ -266,24 +265,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
Creates a VAE config for diffusers based on the config of the original MusicLDM model. Compared to the original Creates a VAE config for diffusers based on the config of the original MusicLDM model. Compared to the original
Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE. Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
""" """
vae_params = original_config.model.params.first_stage_config.params.ddconfig vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
_ = original_config.model.params.first_stage_config.params.embed_dim _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215 scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
config = { config = {
"sample_size": image_size, "sample_size": image_size,
"in_channels": vae_params.in_channels, "in_channels": vae_params["in_channels"],
"out_channels": vae_params.out_ch, "out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types), "down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types), "up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels), "block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels, "latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params.num_res_blocks, "layers_per_block": vae_params["num_res_blocks"],
"scaling_factor": float(scaling_factor), "scaling_factor": float(scaling_factor),
} }
return config return config
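A quick illustrative sketch of how the learnt scaling factor above is picked (the checkpoint dict and value here are made up, not from a real MusicLDM checkpoint):

# Illustrative only: when the original training config sets `scale_by_std`, the
# conversion takes the scaling factor learnt during training from the checkpoint;
# otherwise it falls back to the Stable Diffusion default of 0.18215.
original_config = {"model": {"params": {"scale_by_std": True}}}
checkpoint = {"scale_factor": 0.33}  # hypothetical learnt value

params = original_config["model"]["params"]
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in params else 0.18215
print(float(scaling_factor))  # 0.33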
...@@ -292,9 +291,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int): ...@@ -292,9 +291,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
def create_diffusers_schedular(original_config): def create_diffusers_schedular(original_config):
schedular = DDIMScheduler( schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps, num_train_timesteps=original_config["model"]["params"]["timesteps"],
beta_start=original_config.model.params.linear_start, beta_start=original_config["model"]["params"]["linear_start"],
beta_end=original_config.model.params.linear_end, beta_end=original_config["model"]["params"]["linear_end"],
beta_schedule="scaled_linear", beta_schedule="scaled_linear",
) )
return schedular return schedular
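As a side note on the `beta_schedule="scaled_linear"` choice above: diffusers builds that schedule by interpolating linearly between the square roots of `beta_start` and `beta_end` and squaring the result, which matches the original LDM `linear_start`/`linear_end` convention. A small self-contained sketch (the values are placeholders, not from a real config):

import torch

num_train_timesteps, beta_start, beta_end = 1000, 0.0015, 0.0195
# scaled_linear: linear in sqrt(beta), then squared
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)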
...@@ -674,17 +673,17 @@ def create_transformers_vocoder_config(original_config): ...@@ -674,17 +673,17 @@ def create_transformers_vocoder_config(original_config):
""" """
Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model. Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
""" """
vocoder_params = original_config.model.params.vocoder_config.params vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
config = { config = {
"model_in_dim": vocoder_params.num_mels, "model_in_dim": vocoder_params["num_mels"],
"sampling_rate": vocoder_params.sampling_rate, "sampling_rate": vocoder_params["sampling_rate"],
"upsample_initial_channel": vocoder_params.upsample_initial_channel, "upsample_initial_channel": vocoder_params["upsample_initial_channel"],
"upsample_rates": list(vocoder_params.upsample_rates), "upsample_rates": list(vocoder_params["upsample_rates"]),
"upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes), "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
"resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes), "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
"resblock_dilation_sizes": [ "resblock_dilation_sizes": [
list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
], ],
"normalize_before": False, "normalize_before": False,
} }
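The dict above mirrors the constructor arguments of `transformers.SpeechT5HifiGanConfig`; a hedged sketch of how such a config is typically turned into a vocoder module (the numeric values are illustrative, not read from a checkpoint):

from transformers import SpeechT5HifiGan, SpeechT5HifiGanConfig

# Sketch only: build the transformers config from the converted dict and
# instantiate the vocoder; the converted state dict would still be loaded afterwards.
vocoder_config = SpeechT5HifiGanConfig(
    model_in_dim=64,
    sampling_rate=16000,
    upsample_initial_channel=1024,
    upsample_rates=[5, 4, 2, 2, 2],
    upsample_kernel_sizes=[16, 16, 8, 4, 4],
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    normalize_before=False,
)
vocoder = SpeechT5HifiGan(vocoder_config)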
...@@ -823,12 +822,6 @@ def load_pipeline_from_original_MusicLDM_ckpt( ...@@ -823,12 +822,6 @@ def load_pipeline_from_original_MusicLDM_ckpt(
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
return: A MusicLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file. return: A MusicLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
""" """
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if from_safetensors: if from_safetensors:
from safetensors import safe_open from safetensors import safe_open
...@@ -848,9 +841,8 @@ def load_pipeline_from_original_MusicLDM_ckpt( ...@@ -848,9 +841,8 @@ def load_pipeline_from_original_MusicLDM_ckpt(
if original_config_file is None: if original_config_file is None:
original_config = DEFAULT_CONFIG original_config = DEFAULT_CONFIG
original_config = OmegaConf.create(original_config)
else: else:
original_config = OmegaConf.load(original_config_file) original_config = yaml.safe_load(original_config_file)
if num_in_channels is not None: if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
...@@ -874,9 +866,9 @@ def load_pipeline_from_original_MusicLDM_ckpt( ...@@ -874,9 +866,9 @@ def load_pipeline_from_original_MusicLDM_ckpt(
if image_size is None: if image_size is None:
image_size = 512 image_size = 512
num_train_timesteps = original_config.model.params.timesteps num_train_timesteps = original_config["model"]["params"]["timesteps"]
beta_start = original_config.model.params.linear_start beta_start = original_config["model"]["params"]["linear_start"]
beta_end = original_config.model.params.linear_end beta_end = original_config["model"]["params"]["linear_end"]
scheduler = DDIMScheduler( scheduler = DDIMScheduler(
beta_end=beta_end, beta_end=beta_end,
......
...@@ -3,7 +3,7 @@ import io ...@@ -3,7 +3,7 @@ import io
import requests import requests
import torch import torch
from omegaconf import OmegaConf import yaml
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
...@@ -126,7 +126,7 @@ def vae_pt_to_vae_diffuser( ...@@ -126,7 +126,7 @@ def vae_pt_to_vae_diffuser(
) )
io_obj = io.BytesIO(r.content) io_obj = io.BytesIO(r.content)
original_config = OmegaConf.load(io_obj) original_config = yaml.safe_load(io_obj)
image_size = 512 image_size = 512
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
if checkpoint_path.endswith("safetensors"): if checkpoint_path.endswith("safetensors"):
......
...@@ -45,51 +45,45 @@ from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionSchedu ...@@ -45,51 +45,45 @@ from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionSchedu
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
try:
from omegaconf import OmegaConf
except ImportError:
raise ImportError(
"OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
" OmegaConf`."
)
# vqvae model # vqvae model
PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"] PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]
def vqvae_model_from_original_config(original_config): def vqvae_model_from_original_config(original_config):
assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers." assert (
original_config["target"] in PORTED_VQVAES
), f"{original_config['target']} has not yet been ported to diffusers."
original_config = original_config.params original_config = original_config["params"]
original_encoder_config = original_config.encoder_config.params original_encoder_config = original_config["encoder_config"]["params"]
original_decoder_config = original_config.decoder_config.params original_decoder_config = original_config["decoder_config"]["params"]
in_channels = original_encoder_config.in_channels in_channels = original_encoder_config["in_channels"]
out_channels = original_decoder_config.out_ch out_channels = original_decoder_config["out_ch"]
down_block_types = get_down_block_types(original_encoder_config) down_block_types = get_down_block_types(original_encoder_config)
up_block_types = get_up_block_types(original_decoder_config) up_block_types = get_up_block_types(original_decoder_config)
assert original_encoder_config.ch == original_decoder_config.ch assert original_encoder_config["ch"] == original_decoder_config["ch"]
assert original_encoder_config.ch_mult == original_decoder_config.ch_mult assert original_encoder_config["ch_mult"] == original_decoder_config["ch_mult"]
block_out_channels = tuple( block_out_channels = tuple(
[original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult] [original_encoder_config["ch"] * a_ch_mult for a_ch_mult in original_encoder_config["ch_mult"]]
) )
assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks assert original_encoder_config["num_res_blocks"] == original_decoder_config["num_res_blocks"]
layers_per_block = original_encoder_config.num_res_blocks layers_per_block = original_encoder_config["num_res_blocks"]
assert original_encoder_config.z_channels == original_decoder_config.z_channels assert original_encoder_config["z_channels"] == original_decoder_config["z_channels"]
latent_channels = original_encoder_config.z_channels latent_channels = original_encoder_config["z_channels"]
num_vq_embeddings = original_config.n_embed num_vq_embeddings = original_config["n_embed"]
# Hard coded value for ResnetBlock.GroupNorm(num_groups) in VQ-diffusion # Hard coded value for ResnetBlock.GroupNorm(num_groups) in VQ-diffusion
norm_num_groups = 32 norm_num_groups = 32
e_dim = original_config.embed_dim e_dim = original_config["embed_dim"]
model = VQModel( model = VQModel(
in_channels=in_channels, in_channels=in_channels,
...@@ -108,9 +102,9 @@ def vqvae_model_from_original_config(original_config): ...@@ -108,9 +102,9 @@ def vqvae_model_from_original_config(original_config):
def get_down_block_types(original_encoder_config): def get_down_block_types(original_encoder_config):
attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions) attn_resolutions = coerce_attn_resolutions(original_encoder_config["attn_resolutions"])
num_resolutions = len(original_encoder_config.ch_mult) num_resolutions = len(original_encoder_config["ch_mult"])
resolution = coerce_resolution(original_encoder_config.resolution) resolution = coerce_resolution(original_encoder_config["resolution"])
curr_res = resolution curr_res = resolution
down_block_types = [] down_block_types = []
...@@ -129,9 +123,9 @@ def get_down_block_types(original_encoder_config): ...@@ -129,9 +123,9 @@ def get_down_block_types(original_encoder_config):
def get_up_block_types(original_decoder_config): def get_up_block_types(original_decoder_config):
attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions) attn_resolutions = coerce_attn_resolutions(original_decoder_config["attn_resolutions"])
num_resolutions = len(original_decoder_config.ch_mult) num_resolutions = len(original_decoder_config["ch_mult"])
resolution = coerce_resolution(original_decoder_config.resolution) resolution = coerce_resolution(original_decoder_config["resolution"])
curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution] curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
up_block_types = [] up_block_types = []
...@@ -150,7 +144,7 @@ def get_up_block_types(original_decoder_config): ...@@ -150,7 +144,7 @@ def get_up_block_types(original_decoder_config):
def coerce_attn_resolutions(attn_resolutions): def coerce_attn_resolutions(attn_resolutions):
attn_resolutions = OmegaConf.to_object(attn_resolutions) attn_resolutions = list(attn_resolutions)
attn_resolutions_ = [] attn_resolutions_ = []
for ar in attn_resolutions: for ar in attn_resolutions:
if isinstance(ar, (list, tuple)): if isinstance(ar, (list, tuple)):
...@@ -161,7 +155,6 @@ def coerce_attn_resolutions(attn_resolutions): ...@@ -161,7 +155,6 @@ def coerce_attn_resolutions(attn_resolutions):
def coerce_resolution(resolution): def coerce_resolution(resolution):
resolution = OmegaConf.to_object(resolution)
if isinstance(resolution, int): if isinstance(resolution, int):
resolution = [resolution, resolution] # H, W resolution = [resolution, resolution] # H, W
elif isinstance(resolution, (tuple, list)): elif isinstance(resolution, (tuple, list)):
...@@ -472,18 +465,18 @@ def transformer_model_from_original_config( ...@@ -472,18 +465,18 @@ def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config original_diffusion_config, original_transformer_config, original_content_embedding_config
): ):
assert ( assert (
original_diffusion_config.target in PORTED_DIFFUSIONS original_diffusion_config["target"] in PORTED_DIFFUSIONS
), f"{original_diffusion_config.target} has not yet been ported to diffusers." ), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
assert ( assert (
original_transformer_config.target in PORTED_TRANSFORMERS original_transformer_config["target"] in PORTED_TRANSFORMERS
), f"{original_transformer_config.target} has not yet been ported to diffusers." ), f"{original_transformer_config['target']} has not yet been ported to diffusers."
assert ( assert (
original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
), f"{original_content_embedding_config.target} has not yet been ported to diffusers." ), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
original_diffusion_config = original_diffusion_config.params original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config.params original_transformer_config = original_transformer_config["params"]
original_content_embedding_config = original_content_embedding_config.params original_content_embedding_config = original_content_embedding_config["params"]
inner_dim = original_transformer_config["n_embd"] inner_dim = original_transformer_config["n_embd"]
...@@ -689,13 +682,11 @@ def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_fee ...@@ -689,13 +682,11 @@ def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_fee
def read_config_file(filename): def read_config_file(filename):
# The yaml file contains annotations that certain values should be # The yaml file contains annotations that certain values should be
# loaded as tuples. By default, OmegaConf will panic when reading # loaded as tuples.
# these. Instead, we can manually read the yaml with the FullLoader and then
# construct the OmegaConf object.
with open(filename) as f: with open(filename) as f:
original_config = yaml.load(f, FullLoader) original_config = yaml.load(f, FullLoader)
return OmegaConf.create(original_config) return original_config
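The comment above refers to `!!python/tuple` tags in the VQ-diffusion yaml files: `yaml.safe_load` rejects such application-specific tags, whereas `yaml.FullLoader` constructs plain tuples for them without allowing arbitrary object construction. A small sketch (the key name is just an example):

import yaml

snippet = "attn_resolutions: !!python/tuple [32]"
# yaml.safe_load(snippet) would raise a ConstructorError on the python/tuple tag.
config = yaml.load(snippet, Loader=yaml.FullLoader)
# config["attn_resolutions"] == (32,)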
# We take separate arguments for the vqvae because the ITHQ vqvae config file # We take separate arguments for the vqvae because the ITHQ vqvae config file
...@@ -792,9 +783,9 @@ if __name__ == "__main__": ...@@ -792,9 +783,9 @@ if __name__ == "__main__":
original_config = read_config_file(args.original_config_file).model original_config = read_config_file(args.original_config_file)["model"]
diffusion_config = original_config.params.diffusion_config diffusion_config = original_config["params"]["diffusion_config"]
transformer_config = original_config.params.diffusion_config.params.transformer_config transformer_config = original_config["params"]["diffusion_config"]["params"]["transformer_config"]
content_embedding_config = original_config.params.diffusion_config.params.content_emb_config content_embedding_config = original_config["params"]["diffusion_config"]["params"]["content_emb_config"]
pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location) pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)
...@@ -831,7 +822,7 @@ if __name__ == "__main__": ...@@ -831,7 +822,7 @@ if __name__ == "__main__":
# The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
# model, so we pull them off the checkpoint before the checkpoint is deleted. # model, so we pull them off the checkpoint before the checkpoint is deleted.
learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf learnable_classifier_free_sampling_embeddings = diffusion_config["params"]["learnable_cf"]
if learnable_classifier_free_sampling_embeddings: if learnable_classifier_free_sampling_embeddings:
learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"] learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
......
...@@ -14,6 +14,7 @@ $ python convert_zero123_to_diffusers.py \ ...@@ -14,6 +14,7 @@ $ python convert_zero123_to_diffusers.py \
import argparse import argparse
import torch import torch
import yaml
from accelerate import init_empty_weights from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device from accelerate.utils import set_module_tensor_to_device
from pipeline_zero1to3 import CCProjection, Zero1to3StableDiffusionPipeline from pipeline_zero1to3 import CCProjection, Zero1to3StableDiffusionPipeline
...@@ -38,51 +39,54 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -38,51 +39,54 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
Creates a config for the diffusers based on the config of the LDM model. Creates a config for the diffusers based on the config of the LDM model.
""" """
if controlnet: if controlnet:
unet_params = original_config.model.params.control_stage_config.params unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
else: else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: if (
unet_params = original_config.model.params.unet_config.params "unet_config" in original_config["model"]["params"]
and original_config["model"]["params"]["unet_config"] is not None
):
unet_params = original_config["model"]["params"]["unet_config"]["params"]
else: else:
unet_params = original_config.model.params.network_config.params unet_params = original_config["model"]["params"]["network_config"]["params"]
vae_params = original_config.model.params.first_stage_config.params.ddconfig vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = [] down_block_types = []
resolution = 1 resolution = 1
for i in range(len(block_out_channels)): for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type) down_block_types.append(block_type)
if i != len(block_out_channels) - 1: if i != len(block_out_channels) - 1:
resolution *= 2 resolution *= 2
up_block_types = [] up_block_types = []
for i in range(len(block_out_channels)): for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type) up_block_types.append(block_type)
resolution //= 2 resolution //= 2
if unet_params.transformer_depth is not None: if unet_params["transformer_depth"] is not None:
transformer_layers_per_block = ( transformer_layers_per_block = (
unet_params.transformer_depth unet_params["transformer_depth"]
if isinstance(unet_params.transformer_depth, int) if isinstance(unet_params["transformer_depth"], int)
else list(unet_params.transformer_depth) else list(unet_params["transformer_depth"])
) )
else: else:
transformer_layers_per_block = 1 transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
use_linear_projection = ( use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
) )
if use_linear_projection: if use_linear_projection:
# stable diffusion 2-base-512 and 2-768 # stable diffusion 2-base-512 and 2-768
if head_dim is None: if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
class_embed_type = None class_embed_type = None
addition_embed_type = None addition_embed_type = None
...@@ -90,13 +94,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -90,13 +94,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
projection_class_embeddings_input_dim = None projection_class_embeddings_input_dim = None
context_dim = None context_dim = None
if unet_params.context_dim is not None: if unet_params["context_dim"] is not None:
context_dim = ( context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] unet_params["context_dim"]
if isinstance(unet_params["context_dim"], int)
else unet_params["context_dim"][0]
) )
if "num_classes" in unet_params: if "num_classes" in unet_params:
if unet_params.num_classes == "sequential": if unet_params["num_classes"] == "sequential":
if context_dim in [2048, 1280]: if context_dim in [2048, 1280]:
# SDXL # SDXL
addition_embed_type = "text_time" addition_embed_type = "text_time"
...@@ -104,16 +110,16 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -104,16 +110,16 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
else: else:
class_embed_type = "projection" class_embed_type = "projection"
assert "adm_in_channels" in unet_params assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
else: else:
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params['num_classes']}")
config = { config = {
"sample_size": image_size // vae_scale_factor, "sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels, "in_channels": unet_params["in_channels"],
"down_block_types": tuple(down_block_types), "down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels), "block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks, "layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": context_dim, "cross_attention_dim": context_dim,
"attention_head_dim": head_dim, "attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection, "use_linear_projection": use_linear_projection,
...@@ -125,9 +131,9 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -125,9 +131,9 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
} }
if controlnet: if controlnet:
config["conditioning_channels"] = unet_params.hint_channels config["conditioning_channels"] = unet_params["hint_channels"]
else: else:
config["out_channels"] = unet_params.out_channels config["out_channels"] = unet_params["out_channels"]
config["up_block_types"] = tuple(up_block_types) config["up_block_types"] = tuple(up_block_types)
return config return config
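To make the block-type logic above concrete, here is a worked example using Stable-Diffusion-v1-style values (illustrative, not read from the zero123 config):

unet_params = {
    "model_channels": 320,
    "channel_mult": [1, 2, 4, 4],
    "attention_resolutions": [4, 2, 1],
}
block_out_channels = [unet_params["model_channels"] * m for m in unet_params["channel_mult"]]
# -> [320, 640, 1280, 1280]

down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
    block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
    down_block_types.append(block_type)
    if i != len(block_out_channels) - 1:
        resolution *= 2
# resolution walks 1 -> 2 -> 4 -> 8, so only the deepest level drops cross attention:
# ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]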
...@@ -487,22 +493,22 @@ def create_vae_diffusers_config(original_config, image_size: int): ...@@ -487,22 +493,22 @@ def create_vae_diffusers_config(original_config, image_size: int):
""" """
Creates a config for the diffusers based on the config of the LDM model. Creates a config for the diffusers based on the config of the LDM model.
""" """
vae_params = original_config.model.params.first_stage_config.params.ddconfig vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
_ = original_config.model.params.first_stage_config.params.embed_dim _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = { config = {
"sample_size": image_size, "sample_size": image_size,
"in_channels": vae_params.in_channels, "in_channels": vae_params["in_channels"],
"out_channels": vae_params.out_ch, "out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types), "down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types), "up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels), "block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels, "latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params.num_res_blocks, "layers_per_block": vae_params["num_res_blocks"],
} }
return config return config
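And a brief sketch of how a config built this way is typically consumed (the values are generic SD-style VAE defaults, not taken from a checkpoint):

from diffusers import AutoencoderKL

vae_config = {
    "sample_size": 256,
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": ("DownEncoderBlock2D",) * 4,
    "up_block_types": ("UpDecoderBlock2D",) * 4,
    "block_out_channels": (128, 256, 512, 512),
    "latent_channels": 4,
    "layers_per_block": 2,
}
vae = AutoencoderKL(**vae_config)
# vae.load_state_dict(converted_vae_checkpoint) follows in the actual conversion.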
...@@ -679,18 +685,16 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex ...@@ -679,18 +685,16 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex
del ckpt del ckpt
torch.cuda.empty_cache() torch.cuda.empty_cache()
from omegaconf import OmegaConf original_config = yaml.safe_load(original_config_file)
original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
original_config = OmegaConf.load(original_config_file)
original_config.model.params.cond_stage_config.target.split(".")[-1]
num_in_channels = 8 num_in_channels = 8
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
prediction_type = "epsilon" prediction_type = "epsilon"
image_size = 256 image_size = 256
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 num_train_timesteps = original_config["model"]["params"].get("timesteps", None) or 1000
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 beta_start = original_config["model"]["params"].get("linear_start", None) or 0.02
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 beta_end = original_config["model"]["params"].get("linear_end", None) or 0.085
scheduler = DDIMScheduler( scheduler = DDIMScheduler(
beta_end=beta_end, beta_end=beta_end,
beta_schedule="scaled_linear", beta_schedule="scaled_linear",
...@@ -721,10 +725,10 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex ...@@ -721,10 +725,10 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex
if ( if (
"model" in original_config "model" in original_config
and "params" in original_config.model and "params" in original_config["model"]
and "scale_factor" in original_config.model.params and "scale_factor" in original_config["model"]["params"]
): ):
vae_scaling_factor = original_config.model.params.scale_factor vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
else: else:
vae_scaling_factor = 0.18215 # default SD scaling factor vae_scaling_factor = 0.18215 # default SD scaling factor
......
...@@ -110,7 +110,6 @@ _deps = [ ...@@ -110,7 +110,6 @@ _deps = [
"note_seq", "note_seq",
"librosa", "librosa",
"numpy", "numpy",
"omegaconf",
"parameterized", "parameterized",
"peft>=0.6.0", "peft>=0.6.0",
"protobuf>=3.20.3,<4", "protobuf>=3.20.3,<4",
...@@ -213,7 +212,6 @@ extras["test"] = deps_list( ...@@ -213,7 +212,6 @@ extras["test"] = deps_list(
"invisible-watermark", "invisible-watermark",
"k-diffusion", "k-diffusion",
"librosa", "librosa",
"omegaconf",
"parameterized", "parameterized",
"pytest", "pytest",
"pytest-timeout", "pytest-timeout",
......
...@@ -22,7 +22,6 @@ deps = { ...@@ -22,7 +22,6 @@ deps = {
"note_seq": "note_seq", "note_seq": "note_seq",
"librosa": "librosa", "librosa": "librosa",
"numpy": "numpy", "numpy": "numpy",
"omegaconf": "omegaconf",
"parameterized": "parameterized", "parameterized": "parameterized",
"peft": "peft>=0.6.0", "peft": "peft>=0.6.0",
"protobuf": "protobuf>=3.20.3,<4", "protobuf": "protobuf>=3.20.3,<4",
......
...@@ -17,17 +17,11 @@ from pathlib import Path ...@@ -17,17 +17,11 @@ from pathlib import Path
import requests import requests
import torch import torch
import yaml
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from ..utils import ( from ..utils import deprecate, is_accelerate_available, is_transformers_available, logging
deprecate,
is_accelerate_available,
is_omegaconf_available,
is_transformers_available,
logging,
)
from ..utils.import_utils import BACKENDS_MAPPING
if is_transformers_available(): if is_transformers_available():
...@@ -370,11 +364,6 @@ class FromOriginalVAEMixin: ...@@ -370,11 +364,6 @@ class FromOriginalVAEMixin:
model = AutoencoderKL.from_single_file(url) model = AutoencoderKL.from_single_file(url)
``` ```
""" """
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
from ..models import AutoencoderKL from ..models import AutoencoderKL
# import here to avoid circular dependency # import here to avoid circular dependency
...@@ -452,7 +441,7 @@ class FromOriginalVAEMixin: ...@@ -452,7 +441,7 @@ class FromOriginalVAEMixin:
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
config_file = BytesIO(requests.get(config_url).content) config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(config_file) original_config = yaml.safe_load(config_file)
# default to sd-v1-5 # default to sd-v1-5
image_size = image_size or 512 image_size = image_size or 512
...@@ -463,10 +452,10 @@ class FromOriginalVAEMixin: ...@@ -463,10 +452,10 @@ class FromOriginalVAEMixin:
if scaling_factor is None: if scaling_factor is None:
if ( if (
"model" in original_config "model" in original_config
and "params" in original_config.model and "params" in original_config["model"]
and "scale_factor" in original_config.model.params and "scale_factor" in original_config["model"]["params"]
): ):
vae_scaling_factor = original_config.model.params.scale_factor vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
else: else:
vae_scaling_factor = 0.18215 # default SD scaling factor vae_scaling_factor = 0.18215 # default SD scaling factor
......
...@@ -21,6 +21,7 @@ from typing import Dict, Optional, Union ...@@ -21,6 +21,7 @@ from typing import Dict, Optional, Union
import requests import requests
import torch import torch
import yaml
from transformers import ( from transformers import (
AutoFeatureExtractor, AutoFeatureExtractor,
BertTokenizerFast, BertTokenizerFast,
...@@ -50,8 +51,7 @@ from ...schedulers import ( ...@@ -50,8 +51,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
UnCLIPScheduler, UnCLIPScheduler,
) )
from ...utils import is_accelerate_available, is_omegaconf_available, logging from ...utils import is_accelerate_available, logging
from ...utils.import_utils import BACKENDS_MAPPING
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -237,51 +237,54 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -237,51 +237,54 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
Creates a config for the diffusers based on the config of the LDM model. Creates a config for the diffusers based on the config of the LDM model.
""" """
if controlnet: if controlnet:
unet_params = original_config.model.params.control_stage_config.params unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
else: else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: if (
unet_params = original_config.model.params.unet_config.params "unet_config" in original_config["model"]["params"]
and original_config["model"]["params"]["unet_config"] is not None
):
unet_params = original_config["model"]["params"]["unet_config"]["params"]
else: else:
unet_params = original_config.model.params.network_config.params unet_params = original_config["model"]["params"]["network_config"]["params"]
vae_params = original_config.model.params.first_stage_config.params.ddconfig vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = [] down_block_types = []
resolution = 1 resolution = 1
for i in range(len(block_out_channels)): for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type) down_block_types.append(block_type)
if i != len(block_out_channels) - 1: if i != len(block_out_channels) - 1:
resolution *= 2 resolution *= 2
up_block_types = [] up_block_types = []
for i in range(len(block_out_channels)): for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type) up_block_types.append(block_type)
resolution //= 2 resolution //= 2
if unet_params.transformer_depth is not None: if unet_params["transformer_depth"] is not None:
transformer_layers_per_block = ( transformer_layers_per_block = (
unet_params.transformer_depth unet_params["transformer_depth"]
if isinstance(unet_params.transformer_depth, int) if isinstance(unet_params["transformer_depth"], int)
else list(unet_params.transformer_depth) else list(unet_params["transformer_depth"])
) )
else: else:
transformer_layers_per_block = 1 transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
use_linear_projection = ( use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
) )
if use_linear_projection: if use_linear_projection:
# stable diffusion 2-base-512 and 2-768 # stable diffusion 2-base-512 and 2-768
if head_dim is None: if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
class_embed_type = None class_embed_type = None
addition_embed_type = None addition_embed_type = None
...@@ -289,13 +292,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -289,13 +292,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
projection_class_embeddings_input_dim = None projection_class_embeddings_input_dim = None
context_dim = None context_dim = None
if unet_params.context_dim is not None: if unet_params["context_dim"] is not None:
context_dim = ( context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] unet_params["context_dim"]
if isinstance(unet_params["context_dim"], int)
else unet_params["context_dim"][0]
) )
if "num_classes" in unet_params: if "num_classes" in unet_params:
if unet_params.num_classes == "sequential": if unet_params["num_classes"] == "sequential":
if context_dim in [2048, 1280]: if context_dim in [2048, 1280]:
# SDXL # SDXL
addition_embed_type = "text_time" addition_embed_type = "text_time"
...@@ -303,14 +308,14 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -303,14 +308,14 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
else: else:
class_embed_type = "projection" class_embed_type = "projection"
assert "adm_in_channels" in unet_params assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
config = { config = {
"sample_size": image_size // vae_scale_factor, "sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels, "in_channels": unet_params["in_channels"],
"down_block_types": tuple(down_block_types), "down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels), "block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks, "layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": context_dim, "cross_attention_dim": context_dim,
"attention_head_dim": head_dim, "attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection, "use_linear_projection": use_linear_projection,
...@@ -322,15 +327,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa ...@@ -322,15 +327,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
} }
if "disable_self_attentions" in unet_params: if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions config["only_cross_attention"] = unet_params["disable_self_attentions"]
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
config["num_class_embeds"] = unet_params.num_classes config["num_class_embeds"] = unet_params["num_classes"]
if controlnet: if controlnet:
config["conditioning_channels"] = unet_params.hint_channels config["conditioning_channels"] = unet_params["hint_channels"]
else: else:
config["out_channels"] = unet_params.out_channels config["out_channels"] = unet_params["out_channels"]
config["up_block_types"] = tuple(up_block_types) config["up_block_types"] = tuple(up_block_types)
return config return config
...@@ -340,38 +345,38 @@ def create_vae_diffusers_config(original_config, image_size: int): ...@@ -340,38 +345,38 @@ def create_vae_diffusers_config(original_config, image_size: int):
""" """
Creates a config for the diffusers based on the config of the LDM model. Creates a config for the diffusers based on the config of the LDM model.
""" """
vae_params = original_config.model.params.first_stage_config.params.ddconfig vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
_ = original_config.model.params.first_stage_config.params.embed_dim _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = { config = {
"sample_size": image_size, "sample_size": image_size,
"in_channels": vae_params.in_channels, "in_channels": vae_params["in_channels"],
"out_channels": vae_params.out_ch, "out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types), "down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types), "up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels), "block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels, "latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params.num_res_blocks, "layers_per_block": vae_params["num_res_blocks"],
} }
return config return config
def create_diffusers_schedular(original_config): def create_diffusers_schedular(original_config):
schedular = DDIMScheduler( schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps, num_train_timesteps=original_config["model"]["params"]["timesteps"],
beta_start=original_config.model.params.linear_start, beta_start=original_config["model"]["params"]["linear_start"],
beta_end=original_config.model.params.linear_end, beta_end=original_config["model"]["params"]["linear_end"],
beta_schedule="scaled_linear", beta_schedule="scaled_linear",
) )
return schedular return schedular
def create_ldm_bert_config(original_config): def create_ldm_bert_config(original_config):
bert_params = original_config.model.params.cond_stage_config.params bert_params = original_config["model"]["params"]["cond_stage_config"]["params"]
config = LDMBertConfig( config = LDMBertConfig(
d_model=bert_params.n_embed, d_model=bert_params["n_embed"],
encoder_layers=bert_params.n_layer, encoder_layers=bert_params["n_layer"],
...@@ -1006,9 +1011,9 @@ def stable_unclip_image_encoder(original_config, local_files_only=False): ...@@ -1006,9 +1011,9 @@ def stable_unclip_image_encoder(original_config, local_files_only=False):
encoders. encoders.
""" """
image_embedder_config = original_config.model.params.embedder_config image_embedder_config = original_config["model"]["params"]["embedder_config"]
sd_clip_image_embedder_class = image_embedder_config.target sd_clip_image_embedder_class = image_embedder_config["target"]
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
if sd_clip_image_embedder_class == "ClipImageEmbedder": if sd_clip_image_embedder_class == "ClipImageEmbedder":
...@@ -1047,8 +1052,8 @@ def stable_unclip_image_noising_components( ...@@ -1047,8 +1052,8 @@ def stable_unclip_image_noising_components(
If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
""" """
noise_aug_config = original_config.model.params.noise_aug_config noise_aug_config = original_config["model"]["params"]["noise_aug_config"]
noise_aug_class = noise_aug_config.target noise_aug_class = noise_aug_config["target"]
noise_aug_class = noise_aug_class.split(".")[-1] noise_aug_class = noise_aug_class.split(".")[-1]
if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
...@@ -1245,11 +1250,6 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1245,11 +1250,6 @@ def download_from_original_stable_diffusion_ckpt(
if prediction_type == "v-prediction": if prediction_type == "v-prediction":
prediction_type = "v_prediction" prediction_type = "v_prediction"
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if isinstance(checkpoint_path_or_dict, str): if isinstance(checkpoint_path_or_dict, str):
if from_safetensors: if from_safetensors:
from safetensors.torch import load_file as safe_load from safetensors.torch import load_file as safe_load
...@@ -1318,18 +1318,18 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1318,18 +1318,18 @@ def download_from_original_stable_diffusion_ckpt(
if config_url is not None: if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content) original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file) original_config = yaml.safe_load(original_config_file)
# Convert the text model. # Convert the text model.
if ( if (
model_type is None model_type is None
and "cond_stage_config" in original_config.model.params and "cond_stage_config" in original_config["model"]["params"]
and original_config.model.params.cond_stage_config is not None and original_config["model"]["params"]["cond_stage_config"] is not None
): ):
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
elif model_type is None and original_config.model.params.network_config is not None: elif model_type is None and original_config["model"]["params"]["network_config"] is not None:
if original_config.model.params.network_config.params.context_dim == 2048: if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048:
model_type = "SDXL" model_type = "SDXL"
else: else:
model_type = "SDXL-Refiner" model_type = "SDXL-Refiner"
...@@ -1354,7 +1354,7 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1354,7 +1354,7 @@ def download_from_original_stable_diffusion_ckpt(
elif num_in_channels is None: elif num_in_channels is None:
num_in_channels = 4 num_in_channels = 4
if "unet_config" in original_config.model.params: if "unet_config" in original_config["model"]["params"]:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if ( if (
...@@ -1375,13 +1375,16 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1375,13 +1375,16 @@ def download_from_original_stable_diffusion_ckpt(
if image_size is None: if image_size is None:
image_size = 512 image_size = 512
if controlnet is None and "control_stage_config" in original_config.model.params: if controlnet is None and "control_stage_config" in original_config["model"]["params"]:
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
controlnet = convert_controlnet_checkpoint( controlnet = convert_controlnet_checkpoint(
checkpoint, original_config, path, image_size, upcast_attention, extract_ema checkpoint, original_config, path, image_size, upcast_attention, extract_ema
) )
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 if "timesteps" in original_config["model"]["params"]:
num_train_timesteps = original_config["model"]["params"]["timesteps"]
else:
num_train_timesteps = 1000
if model_type in ["SDXL", "SDXL-Refiner"]: if model_type in ["SDXL", "SDXL-Refiner"]:
scheduler_dict = { scheduler_dict = {
...@@ -1400,8 +1403,15 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1400,8 +1403,15 @@ def download_from_original_stable_diffusion_ckpt(
scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
scheduler_type = "euler" scheduler_type = "euler"
else: else:
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 if "linear_start" in original_config["model"]["params"]:
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 beta_start = original_config["model"]["params"]["linear_start"]
else:
beta_start = 0.02
if "linear_end" in original_config["model"]["params"]:
beta_end = original_config["model"]["params"]["linear_end"]
else:
beta_end = 0.085
scheduler = DDIMScheduler( scheduler = DDIMScheduler(
beta_end=beta_end, beta_end=beta_end,
beta_schedule="scaled_linear", beta_schedule="scaled_linear",
...@@ -1435,7 +1445,7 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1435,7 +1445,7 @@ def download_from_original_stable_diffusion_ckpt(
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
if pipeline_class == StableDiffusionUpscalePipeline: if pipeline_class == StableDiffusionUpscalePipeline:
image_size = original_config.model.params.unet_config.params.image_size image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size) unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
...@@ -1464,10 +1474,10 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1464,10 +1474,10 @@ def download_from_original_stable_diffusion_ckpt(
if ( if (
"model" in original_config "model" in original_config
and "params" in original_config.model and "params" in original_config["model"]
and "scale_factor" in original_config.model.params and "scale_factor" in original_config["model"]["params"]
): ):
vae_scaling_factor = original_config.model.params.scale_factor vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
else: else:
vae_scaling_factor = 0.18215 # default SD scaling factor vae_scaling_factor = 0.18215 # default SD scaling factor
...@@ -1803,11 +1813,6 @@ def download_controlnet_from_original_ckpt( ...@@ -1803,11 +1813,6 @@ def download_controlnet_from_original_ckpt(
use_linear_projection: Optional[bool] = None, use_linear_projection: Optional[bool] = None,
cross_attention_dim: Optional[bool] = None, cross_attention_dim: Optional[bool] = None,
) -> DiffusionPipeline: ) -> DiffusionPipeline:
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if from_safetensors: if from_safetensors:
from safetensors import safe_open from safetensors import safe_open
...@@ -1827,12 +1832,12 @@ def download_controlnet_from_original_ckpt( ...@@ -1827,12 +1832,12 @@ def download_controlnet_from_original_ckpt(
while "state_dict" in checkpoint: while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
original_config = OmegaConf.load(original_config_file) original_config = yaml.safe_load(original_config_file)
if num_in_channels is not None: if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if "control_stage_config" not in original_config.model.params: if "control_stage_config" not in original_config["model"]["params"]:
raise ValueError("`control_stage_config` not present in original config") raise ValueError("`control_stage_config` not present in original config")
controlnet = convert_controlnet_checkpoint( controlnet = convert_controlnet_checkpoint(
......
...@@ -66,7 +66,6 @@ from .import_utils import ( ...@@ -66,7 +66,6 @@ from .import_utils import (
is_k_diffusion_version, is_k_diffusion_version,
is_librosa_available, is_librosa_available,
is_note_seq_available, is_note_seq_available,
is_omegaconf_available,
is_onnx_available, is_onnx_available,
is_peft_available, is_peft_available,
is_scipy_available, is_scipy_available,
......
...@@ -223,12 +223,6 @@ try: ...@@ -223,12 +223,6 @@ try:
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_wandb_available = False _wandb_available = False
_omegaconf_available = importlib.util.find_spec("omegaconf") is not None
try:
_omegaconf_version = importlib_metadata.version("omegaconf")
logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}")
except importlib_metadata.PackageNotFoundError:
_omegaconf_available = False
_tensorboard_available = importlib.util.find_spec("tensorboard") _tensorboard_available = importlib.util.find_spec("tensorboard")
try: try:
...@@ -345,10 +339,6 @@ def is_wandb_available(): ...@@ -345,10 +339,6 @@ def is_wandb_available():
return _wandb_available return _wandb_available
def is_omegaconf_available():
return _omegaconf_available
def is_tensorboard_available(): def is_tensorboard_available():
return _tensorboard_available return _tensorboard_available
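The deleted omegaconf entries above followed the module's standard availability-check pattern; a generic sketch of that pattern with a placeholder backend name:

import importlib.util
import importlib.metadata as importlib_metadata

# Generic shape of the checks in this module; "example_backend" is a placeholder.
_example_available = importlib.util.find_spec("example_backend") is not None
try:
    _example_version = importlib_metadata.version("example_backend")
except importlib_metadata.PackageNotFoundError:
    _example_available = False

def is_example_available():
    return _example_available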
...@@ -449,12 +439,6 @@ WANDB_IMPORT_ERROR = """ ...@@ -449,12 +439,6 @@ WANDB_IMPORT_ERROR = """
install wandb` install wandb`
""" """
# docstyle-ignore
OMEGACONF_IMPORT_ERROR = """
{0} requires the omegaconf library but it was not found in your environment. You can install it with pip: `pip
install omegaconf`
"""
# docstyle-ignore # docstyle-ignore
TENSORBOARD_IMPORT_ERROR = """ TENSORBOARD_IMPORT_ERROR = """
{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip {0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip
...@@ -506,7 +490,6 @@ BACKENDS_MAPPING = OrderedDict( ...@@ -506,7 +490,6 @@ BACKENDS_MAPPING = OrderedDict(
("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)),
("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)),
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
......