Commit 3d1b9667 authored by wangwei990215

Update the diffusers folder

parent b6a53272
accelerate>=0.16.0
torchvision
transformers>=4.25.1
wandb
bitsandbytes
deepspeed
peft>=0.6.0
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState, is_initialized
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from modeling_efficient_net_encoder import EfficientNetEncoder
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPTextModel, PreTrainedTokenizerFast
from transformers.utils import ContextManagers
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
from diffusers.optimization import get_scheduler
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")
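# Maps known Hub dataset names to their (image column, caption column) names; used further
# below as a fallback when resolving which dataset columns hold the images and captions.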
DATASET_NAME_MAPPING = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}
def save_model_card(
args,
repo_id: str,
images=None,
repo_folder=None,
):
img_str = ""
if len(images) > 0:
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
yaml = f"""
---
license: mit
base_model: {args.pretrained_prior_model_name_or_path}
datasets:
- {args.dataset_name}
tags:
- wuerstchen
- text-to-image
- diffusers
- diffusers-training
- lora
inference: true
---
"""
model_card = f"""
# LoRA Finetuning - {repo_id}
This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}\n
{img_str}
## Pipeline usage
You can use the pipeline like so:
```python
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype}
)
# load the LoRA weights from the Hub repository:
pipeline.prior_pipe.load_lora_weights("{repo_id}", torch_dtype={args.weight_dtype})
prompt = "<your prompt here>"  # replace with a prompt describing the image you want
image = pipeline(prompt=prompt).images[0]
image.save("my_image.png")
```
## Training info
These are the key hyperparameters used during training:
* LoRA rank: {args.rank}
* Epochs: {args.num_train_epochs}
* Learning rate: {args.learning_rate}
* Batch size: {args.train_batch_size}
* Gradient accumulation steps: {args.gradient_accumulation_steps}
* Image resolution: {args.resolution}
* Mixed-precision: {args.mixed_precision}
"""
wandb_info = ""
if is_wandb_available():
wandb_run_url = None
if wandb.run is not None:
wandb_run_url = wandb.run.url
if wandb_run_url is not None:
wandb_info = f"""
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
"""
model_card += wandb_info
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
logger.info("Running validation... ")
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior=accelerator.unwrap_model(prior),
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for i in range(len(args.validation_prompts)):
with torch.cuda.amp.autocast():
image = pipeline(
args.validation_prompts[i],
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
generator=generator,
height=args.resolution,
width=args.resolution,
).images[0]
images.append(image)
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
elif tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
for i, image in enumerate(images)
]
}
)
else:
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
return images
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
parser.add_argument(
"--rank",
type=int,
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--pretrained_decoder_model_name_or_path",
type=str,
default="warp-ai/wuerstchen",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_prior_model_name_or_path",
type=str,
default="warp-ai/wuerstchen-prior",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--validation_prompts",
type=str,
default=None,
nargs="+",
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
)
parser.add_argument(
"--output_dir",
type=str,
default="wuerstchen-model-finetuned-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="learning rate",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument(
"--adam_weight_decay",
type=float,
default=0.0,
required=False,
help="weight decay_to_use",
)
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--validation_epochs",
type=int,
default=5,
help="Run validation every X epochs.",
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-fine-tune",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")
return args
def main():
args = parse_args()
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# Load scheduler, effnet, tokenizer, clip_model
noise_scheduler = DDPMWuerstchenScheduler()
tokenizer = PreTrainedTokenizerFast.from_pretrained(
args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
)
def deepspeed_zero_init_disabled_context_manager():
"""
returns either a context list that includes one that will disable zero.Init or an empty context list
"""
deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
if deepspeed_plugin is None:
return []
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
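# The frozen encoders below are loaded inside a context that disables DeepSpeed ZeRO-3's
# zero.Init (see the helper above); presumably this keeps full, unsharded copies of these
# non-trained models on every process so they can be used for inference-only forward passes.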
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
image_encoder = EfficientNetEncoder()
image_encoder.load_state_dict(state_dict["effnet_state_dict"])
image_encoder.eval()
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
).eval()
# Freeze text_encoder and image_encoder, cast them to weight_dtype and move to device
text_encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
image_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# load prior model, cast to weight_dtype and move to device
prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
prior.to(accelerator.device, dtype=weight_dtype)
# lora attn processor
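# Note: only the attention projections listed in `target_modules` receive LoRA adapters, and
# with `lora_alpha` equal to the rank the effective LoRA scaling factor (alpha / r) is 1.0.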
prior_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
# Add adapter and make sure the trainable params are in float32.
prior.add_adapter(prior_lora_config)
if args.mixed_precision == "fp16":
for param in prior.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
prior_lora_layers_to_save = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(prior))):
prior_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
WuerstchenPriorPipeline.save_lora_weights(
output_dir,
unet_lora_layers=prior_lora_layers_to_save,
)
def load_model_hook(models, input_dir):
prior_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(accelerator.unwrap_model(prior))):
prior_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas = WuerstchenPriorPipeline.lora_state_dict(input_dir)
WuerstchenPriorPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=prior_)
WuerstchenPriorPipeline.load_lora_into_text_encoder(
lora_state_dict,
network_alphas=network_alphas,
)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
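# (When --use_8bit_adam is set, bitsandbytes' AdamW8bit keeps optimizer state in 8-bit,
# which markedly reduces optimizer memory compared to the standard torch.optim.AdamW.)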
params_to_optimize = list(filter(lambda p: p.requires_grad, prior.parameters()))
optimizer = optimizer_cls(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
text_input_ids = inputs.input_ids
text_mask = inputs.attention_mask.bool()
return text_input_ids, text_mask
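# Resize/center-crop images to the training resolution and normalize with the standard
# ImageNet mean/std, the usual statistics for ImageNet-pretrained EfficientNet backbones.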
effnet_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
text_mask = torch.stack([example["text_mask"] for example in examples])
return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
prior, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_config = dict(vars(args))
tracker_config.pop("validation_prompts")
accelerator.init_trackers(args.tracker_project_name, tracker_config)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
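# Recover the optimizer-step count from the checkpoint folder name ("checkpoint-<step>") and
# translate it into an epoch index plus the number of dataloader batches to skip below.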
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
prior.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(prior):
# Convert images to latent space
text_input_ids, text_mask, effnet_images = (
batch["text_input_ids"],
batch["text_mask"],
batch["effnet_pixel_values"].to(weight_dtype),
)
with torch.no_grad():
text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
prompt_embeds = text_encoder_output.last_hidden_state
image_embeds = image_encoder(effnet_images)
# scale
image_embeds = image_embeds.add(1.0).div(42.0)
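# (The shift-by-1/divide-by-42 above rescales the EfficientNet embeddings into the value
# range used as the prior's latent space; the constants presumably come from the original
# Würstchen training setup.)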
# Sample noise that we'll add to the image_embeds
noise = torch.randn_like(image_embeds)
bsz = image_embeds.shape[0]
# Sample a random timestep for each image
timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
# add noise to latent
noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
# Predict the noise residual and compute the loss
pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
# vanilla loss
loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
prior = accelerator.unwrap_model(prior)
prior = prior.to(torch.float32)
prior_lora_state_dict = get_peft_model_state_dict(prior)
WuerstchenPriorPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=prior_lora_state_dict,
)
# Run a final round of inference.
images = []
if args.validation_prompts is not None:
logger.info("Running inference for collecting generated images...")
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
# load lora weights
pipeline.prior_pipe.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
pipeline.set_progress_bar_config(disable=True)
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
for i in range(len(args.validation_prompts)):
with torch.cuda.amp.autocast():
image = pipeline(
args.validation_prompts[i],
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
generator=generator,
width=args.resolution,
height=args.resolution,
).images[0]
images.append(image)
if args.push_to_hub:
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training()
if __name__ == "__main__":
main()
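# Example launch command (illustrative sketch only; the script filename, dataset and flag
# values below are assumptions and should be adapted to your setup):
#   accelerate launch train_text_to_image_lora_prior.py \
#     --dataset_name="lambdalabs/pokemon-blip-captions" \
#     --resolution=512 --train_batch_size=1 --gradient_accumulation_steps=4 \
#     --rank=4 --learning_rate=1e-4 --lr_scheduler="constant" --lr_warmup_steps=0 \
#     --checkpointing_steps=500 --validation_prompts "a cute pokemon" \
#     --output_dir="wuerstchen-prior-pokemon-lora"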
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path
import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState, is_initialized
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from modeling_efficient_net_encoder import EfficientNetEncoder
from packaging import version
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPTextModel, PreTrainedTokenizerFast
from transformers.utils import ContextManagers
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler
from diffusers.optimization import get_scheduler
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.27.0")
logger = get_logger(__name__, log_level="INFO")
DATASET_NAME_MAPPING = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}
def save_model_card(
args,
repo_id: str,
images=None,
repo_folder=None,
):
img_str = ""
if len(images) > 0:
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
yaml = f"""
---
license: mit
base_model: {args.pretrained_prior_model_name_or_path}
datasets:
- {args.dataset_name}
tags:
- wuerstchen
- text-to-image
- diffusers
- diffusers-training
inference: true
---
"""
model_card = f"""
# Finetuning - {repo_id}
This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}\n
{img_str}
## Pipeline usage
You can use the pipeline like so:
```python
from diffusers import DiffusionPipeline
import torch
pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype={args.weight_dtype})
pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype})
prompt = "{args.validation_prompts[0]}"
(image_embeds,) = pipe_prior(prompt).to_tuple()
image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0]
image.save("my_image.png")
```
## Training info
These are the key hyperparameters used during training:
* Epochs: {args.num_train_epochs}
* Learning rate: {args.learning_rate}
* Batch size: {args.train_batch_size}
* Gradient accumulation steps: {args.gradient_accumulation_steps}
* Image resolution: {args.resolution}
* Mixed-precision: {args.mixed_precision}
"""
wandb_info = ""
if is_wandb_available():
wandb_run_url = None
if wandb.run is not None:
wandb_run_url = wandb.run.url
if wandb_run_url is not None:
wandb_info = f"""
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
"""
model_card += wandb_info
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
logger.info("Running validation... ")
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior_prior=accelerator.unwrap_model(prior),
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for i in range(len(args.validation_prompts)):
with torch.autocast("cuda"):
image = pipeline(
args.validation_prompts[i],
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
generator=generator,
height=args.resolution,
width=args.resolution,
).images[0]
images.append(image)
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
elif tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
for i, image in enumerate(images)
]
}
)
else:
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
torch.cuda.empty_cache()
return images
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
parser.add_argument(
"--pretrained_decoder_model_name_or_path",
type=str,
default="warp-ai/wuerstchen",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_prior_model_name_or_path",
type=str,
default="warp-ai/wuerstchen-prior",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--validation_prompts",
type=str,
default=None,
nargs="+",
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
)
parser.add_argument(
"--output_dir",
type=str,
default="wuerstchen-model-finetuned",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="learning rate",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument(
"--adam_weight_decay",
type=float,
default=0.0,
required=False,
help="weight decay_to_use",
)
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--validation_epochs",
type=int,
default=5,
help="Run validation every X epochs.",
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-fine-tune",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")
return args
def main():
args = parse_args()
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# Load scheduler, effnet, tokenizer, clip_model
noise_scheduler = DDPMWuerstchenScheduler()
tokenizer = PreTrainedTokenizerFast.from_pretrained(
args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
)
def deepspeed_zero_init_disabled_context_manager():
"""
returns either a context list that includes one that will disable zero.Init or an empty context list
"""
deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
if deepspeed_plugin is None:
return []
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
image_encoder = EfficientNetEncoder()
image_encoder.load_state_dict(state_dict["effnet_state_dict"])
image_encoder.eval()
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
).eval()
# Freeze text_encoder and image_encoder
text_encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
# load prior model
prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
# Create EMA for the prior
if args.use_ema:
ema_prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
ema_prior = EMAModel(ema_prior.parameters(), model_cls=WuerstchenPrior, model_config=ema_prior.config)
ema_prior.to(accelerator.device)
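# (The EMA copy is updated after every optimizer step via `ema_prior.step(prior.parameters())`
# further below; it maintains a smoothed average of the trained weights for evaluation/export.)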
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if args.use_ema:
ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema"))
for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "prior"))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), WuerstchenPrior)
ema_prior.load_state_dict(load_model.state_dict())
ema_prior.to(accelerator.device)
del load_model
for i in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()
# load diffusers style into model
load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder="prior")
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())
del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if args.gradient_checkpointing:
prior.enable_gradient_checkpointing()
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
prior.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
text_input_ids = inputs.input_ids
text_mask = inputs.attention_mask.bool()
return text_input_ids, text_mask
effnet_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
text_mask = torch.stack([example["text_mask"] for example in examples])
return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
prior, optimizer, train_dataloader, lr_scheduler
)
image_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
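# The frozen image and text encoders are moved to the device manually here because they are
# not passed through `accelerator.prepare` above (only the trainable prior, optimizer,
# dataloader and LR scheduler are).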
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_config = dict(vars(args))
tracker_config.pop("validation_prompts")
accelerator.init_trackers(args.tracker_project_name, tracker_config)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
prior.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(prior):
# Convert images to latent space
text_input_ids, text_mask, effnet_images = (
batch["text_input_ids"],
batch["text_mask"],
batch["effnet_pixel_values"].to(weight_dtype),
)
with torch.no_grad():
text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
prompt_embeds = text_encoder_output.last_hidden_state
image_embeds = image_encoder(effnet_images)
# scale
image_embeds = image_embeds.add(1.0).div(42.0)
# Sample noise that we'll add to the image_embeds
noise = torch.randn_like(image_embeds)
bsz = image_embeds.shape[0]
# Sample a random timestep for each image
timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
# add noise to latent
noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
# Predict the noise residual and compute loss
pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
# vanilla loss
loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
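# i.e. the usual epsilon-prediction objective: minimize E[ || eps_theta(x_t, t, c) - eps ||^2 ]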
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema_prior.step(prior.parameters())
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
if args.use_ema:
# Store the prior parameters temporarily and load the EMA parameters to perform inference.
ema_prior.store(prior.parameters())
ema_prior.copy_to(prior.parameters())
log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)
if args.use_ema:
# Switch back to the original prior parameters.
ema_prior.restore(prior.parameters())
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
prior = accelerator.unwrap_model(prior)
if args.use_ema:
ema_prior.copy_to(prior.parameters())
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior_prior=prior,
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
)
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, "prior_pipeline"))
# Run a final round of inference.
images = []
if args.validation_prompts is not None:
logger.info("Running inference for collecting generated images...")
pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
pipeline.set_progress_bar_config(disable=True)
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
for i in range(len(args.validation_prompts)):
with torch.autocast("cuda"):
image = pipeline(
args.validation_prompts[i],
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
generator=generator,
width=args.resolution,
height=args.resolution,
).images[0]
images.append(image)
if args.push_to_hub:
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training()
if __name__ == "__main__":
main()
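# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the training script above): one way the prior
# pipeline saved to `<output_dir>/prior_pipeline` could be reloaded for inference.
# The decoder id "warp-ai/wuerstchen" and the local path "output/prior_pipeline"
# are assumptions; substitute the values used for your own run.
# ---------------------------------------------------------------------------
import torch
from diffusers import AutoPipelineForText2Image, WuerstchenPriorPipeline
# load the finetuned prior pipeline saved by the training script
prior_pipe = WuerstchenPriorPipeline.from_pretrained("output/prior_pipeline", torch_dtype=torch.float16)
# plug its prior into the combined text-to-image pipeline, mirroring how the
# training script builds the pipeline with `prior_prior=...`
pipe = AutoPipelineForText2Image.from_pretrained(
"warp-ai/wuerstchen", prior_prior=prior_pipe.prior, torch_dtype=torch.float16
).to("cuda")
image = pipe("a pokemon with blue eyes", height=512, width=512).images[0]
image.save("finetuned_prior_sample.png")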
[tool.ruff]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"]
line-length = 119
# Ignore import violations in all `__init__.py` files.
[tool.ruff.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
"src/diffusers/utils/dummy_*.py" = ["F401"]
[tool.ruff.isort]
lines-after-imports = 2
known-first-party = ["diffusers"]
[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """
import argparse
import json
import os
import torch
from transformers.file_utils import has_file
from diffusers import UNet2DConditionModel, UNet2DModel
do_only_config = False
do_only_weights = True
do_only_renaming = False
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo_path",
default=None,
type=str,
required=True,
help="The config json file corresponding to the architecture.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
config_parameters_to_change = {
"image_size": "sample_size",
"num_res_blocks": "layers_per_block",
"block_channels": "block_out_channels",
"down_blocks": "down_block_types",
"up_blocks": "up_block_types",
"downscale_freq_shift": "freq_shift",
"resnet_num_groups": "norm_num_groups",
"resnet_act_fn": "act_fn",
"resnet_eps": "norm_eps",
"num_head_channels": "attention_head_dim",
}
key_parameters_to_change = {
"time_steps": "time_proj",
"mid": "mid_block",
"downsample_blocks": "down_blocks",
"upsample_blocks": "up_blocks",
}
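# Illustrative example of the renamings above (keys are hypothetical): the config
# entry "num_res_blocks" becomes "layers_per_block", and a weight such as
# "downsample_blocks.0.resnets.0.conv1.weight" is stored as
# "down_blocks.0.resnets.0.conv1.weight".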
subfolder = "" if has_file(args.repo_path, "config.json") else "unet"
with open(os.path.join(args.repo_path, subfolder, "config.json"), "r", encoding="utf-8") as reader:
text = reader.read()
config = json.loads(text)
if do_only_config:
for key in config_parameters_to_change.keys():
config.pop(key, None)
if has_file(args.repo_path, "config.json"):
model = UNet2DModel(**config)
else:
class_name = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel
model = class_name(**config)
if do_only_config:
model.save_config(os.path.join(args.repo_path, subfolder))
config = dict(model.config)
if do_only_renaming:
for key, value in config_parameters_to_change.items():
if key in config:
config[value] = config[key]
del config[key]
config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]
if do_only_weights:
state_dict = torch.load(os.path.join(args.repo_path, subfolder, "diffusion_pytorch_model.bin"))
new_state_dict = {}
for param_key, param_value in state_dict.items():
if param_key.endswith(".op.bias") or param_key.endswith(".op.weight"):
continue
has_changed = False
for key, new_key in key_parameters_to_change.items():
if not has_changed and param_key.split(".")[0] == key:
new_state_dict[".".join([new_key] + param_key.split(".")[1:])] = param_value
has_changed = True
if not has_changed:
new_state_dict[param_key] = param_value
model.load_state_dict(new_state_dict)
model.save_pretrained(os.path.join(args.repo_path, subfolder))
import argparse
import torch
import yaml
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
def convert_ldm_original(checkpoint_path, config_path, output_path):
with open(config_path, "r") as config_file:
config = yaml.safe_load(config_file)
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
keys = list(state_dict.keys())
# extract state_dict for VQVAE
first_stage_dict = {}
first_stage_key = "first_stage_model."
for key in keys:
if key.startswith(first_stage_key):
first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
# extract state_dict for UNetLDM
unet_state_dict = {}
unet_key = "model.diffusion_model."
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
unet_init_args = config["model"]["params"]["unet_config"]["params"]
vqvae = VQModel(**vqvae_init_args).eval()
vqvae.load_state_dict(first_stage_dict)
unet = UNetLDMModel(**unet_init_args).eval()
unet.load_state_dict(unet_state_dict)
noise_scheduler = DDIMScheduler(
timesteps=config["model"]["params"]["timesteps"],
beta_schedule="scaled_linear",
beta_start=config["model"]["params"]["linear_start"],
beta_end=config["model"]["params"]["linear_end"],
clip_sample=False,
)
pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
pipeline.save_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, required=True)
parser.add_argument("--config_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
args = parser.parse_args()
convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
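# Example invocation (script and file names are hypothetical):
#   python convert_ldm_original.py --checkpoint_path ldm.ckpt --config_path ldm_config.yaml --output_path ./ldm-pipeline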
import inspect
import os
from argparse import ArgumentParser
import numpy as np
import torch
from muse import MaskGiTUViT, VQGANModel
from muse import PipelineMuse as OldPipelineMuse
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import VQModel
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.unets.uvit_2d import UVit2DModel
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
from diffusers.schedulers import AmusedScheduler
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True)
# Enable CUDNN deterministic mode
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
device = "cuda"
def main():
args = ArgumentParser()
args.add_argument("--model_256", action="store_true")
args.add_argument("--write_to", type=str, required=False, default=None)
args.add_argument("--transformer_path", type=str, required=False, default=None)
args = args.parse_args()
transformer_path = args.transformer_path
subfolder = "transformer"
if transformer_path is None:
if args.model_256:
transformer_path = "openMUSE/muse-256"
else:
transformer_path = (
"../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/"
)
subfolder = None
old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder)
old_transformer.to(device)
old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae")
old_vae.to(device)
vqvae = make_vqvae(old_vae)
tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
text_encoder.to(device)
transformer = make_transformer(old_transformer, args.model_256)
scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id)
new_pipe = AmusedPipeline(
vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler
)
old_pipe = OldPipelineMuse(
vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer
)
old_pipe.to(device)
if args.model_256:
transformer_seq_len = 256
orig_size = (256, 256)
else:
transformer_seq_len = 1024
orig_size = (512, 512)
old_out = old_pipe(
"dog",
generator=torch.Generator(device).manual_seed(0),
transformer_seq_len=transformer_seq_len,
orig_size=orig_size,
timesteps=12,
)[0]
new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0]
old_out = np.array(old_out)
new_out = np.array(new_out)
diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64))
# assert diff.sum() == 0
print("skipping pipeline full equivalence check")
print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}")
if args.model_256:
assert diff.max() <= 3
assert diff.sum() / diff.size < 0.7
else:
assert diff.max() <= 1
assert diff.sum() / diff.size < 0.4
if args.write_to is not None:
new_pipe.save_pretrained(args.write_to)
def make_transformer(old_transformer, model_256):
args = dict(old_transformer.config)
force_down_up_sample = args["force_down_up_sample"]
signature = inspect.signature(UVit2DModel.__init__)
args_ = {
"downsample": force_down_up_sample,
"upsample": force_down_up_sample,
"block_out_channels": args["block_out_channels"][0],
"sample_size": 16 if model_256 else 32,
}
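# Note: sample_size here is the latent resolution. The VQ-GAN below downsamples by
# a factor of 16 (four downsampling stages), so 256-px inputs give 16x16 latents
# and 512-px inputs give 32x32.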
for s in list(signature.parameters.keys()):
if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]:
continue
args_[s] = args[s]
new_transformer = UVit2DModel(**args_)
new_transformer.to(device)
new_transformer.set_attn_processor(AttnProcessor())
state_dict = old_transformer.state_dict()
state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight")
state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight")
for i in range(22):
state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop(
f"transformer_layers.{i}.attn_layer_norm.weight"
)
state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop(
f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight"
)
state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop(
f"transformer_layers.{i}.attention.query.weight"
)
state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop(
f"transformer_layers.{i}.attention.key.weight"
)
state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop(
f"transformer_layers.{i}.attention.value.weight"
)
state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop(
f"transformer_layers.{i}.attention.out.weight"
)
state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattn_layer_norm.weight"
)
state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop(
f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight"
)
state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattention.query.weight"
)
state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattention.key.weight"
)
state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattention.value.weight"
)
state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop(
f"transformer_layers.{i}.crossattention.out.weight"
)
state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop(
f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight"
)
state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop(
f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight"
)
wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight")
wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight")
proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0)
state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight
state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight")
if force_down_up_sample:
state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight")
state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight")
state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight")
state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight")
state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight")
for i in range(3):
state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.norm.norm.weight"
)
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.channelwise.0.weight"
)
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma"
)
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.channelwise.2.beta"
)
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.channelwise.4.weight"
)
state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
)
state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attention.query.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attention.key.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attention.value.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.attention.out.weight"
)
state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight"
)
state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight"
)
state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.norm.norm.weight"
)
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.channelwise.0.weight"
)
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma"
)
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.channelwise.2.beta"
)
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.channelwise.4.weight"
)
state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
)
state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attention.query.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attention.key.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attention.value.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.attention.out.weight"
)
state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight"
)
state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight"
)
for key in list(state_dict.keys()):
if key.startswith("up_blocks.0"):
key_ = "up_block." + ".".join(key.split(".")[2:])
state_dict[key_] = state_dict.pop(key)
if key.startswith("down_blocks.0"):
key_ = "down_block." + ".".join(key.split(".")[2:])
state_dict[key_] = state_dict.pop(key)
new_transformer.load_state_dict(state_dict)
input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device)
encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device)
cond_embeds = torch.randn((1, 768), device=old_transformer.device)
micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device)
old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds)
old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2)
new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds)
# NOTE: these differences are solely due to using the geglu block that has a single linear layer of
# double output dimension instead of two different linear layers
max_diff = (old_out - new_out).abs().max()
total_diff = (old_out - new_out).abs().sum()
print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}")
assert max_diff < 0.01
assert total_diff < 1500
return new_transformer
def make_vqvae(old_vae):
new_vae = VQModel(
act_fn="silu",
block_out_channels=[128, 256, 256, 512, 768],
down_block_types=[
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
],
in_channels=3,
latent_channels=64,
layers_per_block=2,
norm_num_groups=32,
num_vq_embeddings=8192,
out_channels=3,
sample_size=32,
up_block_types=[
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
],
mid_block_add_attention=False,
lookup_from_codebook=True,
)
new_vae.to(device)
# fmt: off
new_state_dict = {}
old_state_dict = old_vae.state_dict()
new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight")
new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias")
convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0")
convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1")
convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2")
convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3")
convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4")
new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight")
new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias")
new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight")
new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias")
new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight")
new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias")
new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight")
new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias")
new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight")
new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias")
new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight")
new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias")
new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight")
new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias")
new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight")
new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias")
new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight")
new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias")
new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight")
new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias")
new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight")
new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias")
new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight")
new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight")
new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias")
new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight")
new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias")
new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight")
new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias")
new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight")
new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias")
new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight")
new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias")
new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight")
new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias")
new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight")
new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias")
new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight")
new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias")
new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight")
new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias")
new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight")
new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias")
convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4")
convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3")
convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2")
convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1")
convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0")
new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight")
new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias")
new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight")
new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias")
# fmt: on
assert len(old_state_dict.keys()) == 0
new_vae.load_state_dict(new_state_dict)
input = torch.randn((1, 3, 512, 512), device=device)
input = input.clamp(-1, 1)
old_encoder_output = old_vae.quant_conv(old_vae.encoder(input))
new_encoder_output = new_vae.quant_conv(new_vae.encoder(input))
assert (old_encoder_output == new_encoder_output).all()
old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output))
new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output))
# assert (old_decoder_output == new_decoder_output).all()
print("kipping vae decoder equivalence check")
print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}")
old_output = old_vae(input)[0]
new_output = new_vae(input)[0]
# assert (old_output == new_output).all()
print("skipping full vae equivalence check")
print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
return new_vae
def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to):
# fmt: off
new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight")
new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias")
new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight")
new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias")
new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight")
new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias")
new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight")
new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias")
if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict:
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight")
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias")
new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight")
new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias")
new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight")
new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias")
new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight")
new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias")
new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight")
new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias")
if f"{prefix_from}.downsample.conv.weight" in old_state_dict:
new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight")
new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias")
if f"{prefix_from}.upsample.conv.weight" in old_state_dict:
new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight")
new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias")
if f"{prefix_from}.block.2.norm1.weight" in old_state_dict:
new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight")
new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias")
new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight")
new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias")
new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight")
new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias")
new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight")
new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias")
# fmt: on
if __name__ == "__main__":
main()
import argparse
import torch
from safetensors.torch import save_file
def convert_motion_module(original_state_dict):
converted_state_dict = {}
for k, v in original_state_dict.items():
if "pos_encoder" in k:
continue
else:
converted_state_dict[
k.replace(".norms.0", ".norm1")
.replace(".norms.1", ".norm2")
.replace(".ff_norm", ".norm3")
.replace(".attention_blocks.0", ".attn1")
.replace(".attention_blocks.1", ".attn2")
.replace(".temporal_transformer", "")
] = v
return converted_state_dict
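# Illustrative example of the renaming above (key is hypothetical):
#   "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight"
#   -> "down_blocks.0.motion_modules.0.transformer_blocks.0.attn1.to_q.weight"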
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict.keys():
state_dict = state_dict["state_dict"]
conv_state_dict = convert_motion_module(state_dict)
# convert to new format
output_dict = {}
for module_name, params in conv_state_dict.items():
if type(params) is not torch.Tensor:
continue
output_dict.update({f"unet.{module_name}": params})
save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
import argparse
import torch
from diffusers import MotionAdapter
def convert_motion_module(original_state_dict):
converted_state_dict = {}
for k, v in original_state_dict.items():
if "pos_encoder" in k:
continue
else:
converted_state_dict[
k.replace(".norms.0", ".norm1")
.replace(".norms.1", ".norm2")
.replace(".ff_norm", ".norm3")
.replace(".attention_blocks.0", ".attn1")
.replace(".attention_blocks.1", ".attn2")
.replace(".temporal_transformer", "")
] = v
return converted_state_dict
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--use_motion_mid_block", action="store_true")
parser.add_argument("--motion_max_seq_length", type=int, default=32)
parser.add_argument("--save_fp16", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict.keys():
state_dict = state_dict["state_dict"]
conv_state_dict = convert_motion_module(state_dict)
adapter = MotionAdapter(
use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
)
# skip loading position embeddings
adapter.load_state_dict(conv_state_dict, strict=False)
adapter.save_pretrained(args.output_path)
if args.save_fp16:
adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
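# Example invocation (script and file names are hypothetical):
#   python convert_animatediff_motion_module.py --ckpt_path motion_module.ckpt --output_path ./motion-adapter --motion_max_seq_length 32 --save_fp16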
import argparse
import time
from pathlib import Path
from typing import Any, Dict, Literal
import torch
from diffusers import AsymmetricAutoencoderKL
ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
"in_channels": 3,
"out_channels": 3,
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
],
"down_block_out_channels": [128, 256, 512, 512],
"layers_per_down_block": 2,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
],
"up_block_out_channels": [192, 384, 768, 768],
"layers_per_up_block": 3,
"act_fn": "silu",
"latent_channels": 4,
"norm_num_groups": 32,
"sample_size": 256,
"scaling_factor": 0.18215,
}
ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
"in_channels": 3,
"out_channels": 3,
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
],
"down_block_out_channels": [128, 256, 512, 512],
"layers_per_down_block": 2,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
],
"up_block_out_channels": [256, 512, 1024, 1024],
"layers_per_up_block": 5,
"act_fn": "silu",
"latent_channels": 4,
"norm_num_groups": 32,
"sample_size": 256,
"scaling_factor": 0.18215,
}
def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
converted_state_dict = {}
for k, v in original_state_dict.items():
if k.startswith("encoder."):
converted_state_dict[
k.replace("encoder.down.", "encoder.down_blocks.")
.replace("encoder.mid.", "encoder.mid_block.")
.replace("encoder.norm_out.", "encoder.conv_norm_out.")
.replace(".downsample.", ".downsamplers.0.")
.replace(".nin_shortcut.", ".conv_shortcut.")
.replace(".block.", ".resnets.")
.replace(".block_1.", ".resnets.0.")
.replace(".block_2.", ".resnets.1.")
.replace(".attn_1.k.", ".attentions.0.to_k.")
.replace(".attn_1.q.", ".attentions.0.to_q.")
.replace(".attn_1.v.", ".attentions.0.to_v.")
.replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
.replace(".attn_1.norm.", ".attentions.0.group_norm.")
] = v
elif k.startswith("decoder.") and "up_layers" not in k:
converted_state_dict[
k.replace("decoder.encoder.", "decoder.condition_encoder.")
.replace(".norm_out.", ".conv_norm_out.")
.replace(".up.0.", ".up_blocks.3.")
.replace(".up.1.", ".up_blocks.2.")
.replace(".up.2.", ".up_blocks.1.")
.replace(".up.3.", ".up_blocks.0.")
.replace(".block.", ".resnets.")
.replace("mid", "mid_block")
.replace(".0.upsample.", ".0.upsamplers.0.")
.replace(".1.upsample.", ".1.upsamplers.0.")
.replace(".2.upsample.", ".2.upsamplers.0.")
.replace(".nin_shortcut.", ".conv_shortcut.")
.replace(".block_1.", ".resnets.0.")
.replace(".block_2.", ".resnets.1.")
.replace(".attn_1.k.", ".attentions.0.to_k.")
.replace(".attn_1.q.", ".attentions.0.to_q.")
.replace(".attn_1.v.", ".attentions.0.to_v.")
.replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
.replace(".attn_1.norm.", ".attentions.0.group_norm.")
] = v
elif k.startswith("quant_conv."):
converted_state_dict[k] = v
elif k.startswith("post_quant_conv."):
converted_state_dict[k] = v
else:
print(f" skipping key `{k}`")
# fix weights shape
for k, v in converted_state_dict.items():
if (
(k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
and k.endswith("weight")
and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
):
converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]
return converted_state_dict
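# Illustrative example of the mapping above (key is hypothetical):
#   "encoder.down.0.block.0.nin_shortcut.weight" -> "encoder.down_blocks.0.resnets.0.conv_shortcut.weight"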
def get_asymmetric_autoencoder_kl_from_original_checkpoint(
scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
) -> AsymmetricAutoencoderKL:
print("Loading original state_dict")
original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
original_state_dict = original_state_dict["state_dict"]
print("Converting state_dict")
converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
print("Initializing AsymmetricAutoencoderKL model")
asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
print("Loading weight from converted state_dict")
asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
asymmetric_autoencoder_kl.eval()
print("AsymmetricAutoencoderKL successfully initialized")
return asymmetric_autoencoder_kl
if __name__ == "__main__":
start = time.time()
parser = argparse.ArgumentParser()
parser.add_argument(
"--scale",
default=None,
type=str,
required=True,
help="Asymmetric VQGAN scale: `1.5` or `2`",
)
parser.add_argument(
"--original_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the original Asymmetric VQGAN checkpoint",
)
parser.add_argument(
"--output_path",
default=None,
type=str,
required=True,
help="Path to save pretrained AsymmetricAutoencoderKL model",
)
parser.add_argument(
"--map_location",
default="cpu",
type=str,
required=False,
help="The device passed to `map_location` when loading the checkpoint",
)
args = parser.parse_args()
assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` or `2`"
assert Path(args.original_checkpoint_path).is_file()
asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
scale=args.scale,
original_checkpoint_path=args.original_checkpoint_path,
map_location=torch.device(args.map_location),
)
print("Saving pretrained AsymmetricAutoencoderKL")
asymmetric_autoencoder_kl.save_pretrained(args.output_path)
print(f"Done in {time.time() - start:.2f} seconds")
"""
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
"""
import argparse
import os
import tempfile
import torch
from lavis.models import load_model_and_preprocess
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from diffusers import (
AutoencoderKL,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
BLIP2_CONFIG = {
"vision_config": {
"hidden_size": 1024,
"num_hidden_layers": 23,
"num_attention_heads": 16,
"image_size": 224,
"patch_size": 14,
"intermediate_size": 4096,
"hidden_act": "quick_gelu",
},
"qformer_config": {
"cross_attention_frequency": 1,
"encoder_hidden_size": 1024,
"vocab_size": 30523,
},
"num_query_tokens": 16,
}
blip2config = Blip2Config(**BLIP2_CONFIG)
def qformer_model_from_original_config():
qformer = Blip2QFormerModel(blip2config)
return qformer
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
embeddings = {}
embeddings.update(
{
f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
f"{original_embeddings_prefix}.word_embeddings.weight"
]
}
)
embeddings.update(
{
f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
f"{original_embeddings_prefix}.position_embeddings.weight"
]
}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
)
return embeddings
def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
proj_layer = {}
proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
return proj_layer
def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
attention = {}
attention.update(
{
f"{diffuser_attention_prefix}.attention.query.weight": model[
f"{original_attention_prefix}.self.query.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.attention.value.weight": model[
f"{original_attention_prefix}.self.value.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
f"{original_attention_prefix}.output.LayerNorm.weight"
]
}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
f"{original_attention_prefix}.output.LayerNorm.bias"
]
}
)
return attention
def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
output_layers = {}
output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
)
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
)
return output_layers
def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
encoder = {}
for i in range(blip2config.qformer_config.num_hidden_layers):
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
)
)
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
)
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
]
}
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
)
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
)
)
return encoder
def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder_layer = {}
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
)
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
return visual_encoder_layer
def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder = {}
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
.unsqueeze(0)
.unsqueeze(0)
}
)
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.position_embedding": model[
f"{original_prefix}.positional_embedding"
].unsqueeze(0)
}
)
visual_encoder.update(
{f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
)
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
for i in range(blip2config.vision_config.num_hidden_layers):
visual_encoder.update(
visual_encoder_layer_from_original_checkpoint(
model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
)
)
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
return visual_encoder
def qformer_original_checkpoint_to_diffusers_checkpoint(model):
qformer_checkpoint = {}
qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
qformer_checkpoint.update(
encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
)
qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
return qformer_checkpoint
def get_qformer(model):
print("loading qformer")
qformer = qformer_model_from_original_config()
qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
print("done loading qformer")
return qformer
def load_checkpoint_to_model(checkpoint, model):
with tempfile.NamedTemporaryFile(delete=False) as file:
torch.save(checkpoint, file.name)
del checkpoint
model.load_state_dict(torch.load(file.name), strict=False)
os.remove(file.name)
def save_blip_diffusion_model(model, args):
qformer = get_qformer(model)
qformer.eval()
text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
vae.eval()
text_encoder.eval()
scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
image_processor = BlipImageProcessor()
blip_diffusion = BlipDiffusionPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
image_processor=image_processor,
)
blip_diffusion.save_pretrained(args.checkpoint_path)
def main(args):
model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
save_blip_diffusion_model(model.state_dict(), args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
main(args)
import math
import os
import urllib
import warnings
from argparse import ArgumentParser
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.utils import insecure_hashlib
from safetensors.torch import load_file as stl
from tqdm import tqdm
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
from diffusers.models.autoencoders.vae import Encoder
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
args = ArgumentParser()
args.add_argument("--save_pretrained", required=False, default=None, type=str)
args.add_argument("--test_image", required=True, type=str)
args = args.parse_args()
def _extract_into_tensor(arr, timesteps, broadcast_shape):
# from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895
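# arr[timesteps] has shape (batch,); the trailing singleton dims appended below let
# it broadcast against a tensor of `broadcast_shape`, e.g. (batch, C, H, W).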
res = arr[timesteps].float()
dims_to_append = len(broadcast_shape) - len(res.shape)
return res[(...,) + (None,) * dims_to_append]
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
# from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas)
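# i.e. beta_i = 1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), clipped at max_beta,
# so that (up to the clipping) the cumulative product of (1 - beta_i) follows the
# given alpha_bar schedule.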
def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
return download_target
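# Wrapper around the original jitted decoder.pt. It builds a 1024-step cosine alpha-bar noise
# schedule and the consistency-model scalings (c_skip, c_out, c_in) used to turn the network
# output into an x0 prediction; `ckpt` is later swapped for the converted modules so the same
# sampling loop can validate each conversion stage.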
class ConsistencyDecoder:
def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
self.n_distilled_steps = 64
download_target = _download(
"https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt",
download_root,
)
self.ckpt = torch.jit.load(download_target).to(device)
self.device = device
sigma_data = 0.5
betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
@staticmethod
def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
with torch.no_grad():
space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
if truncate_start:
rounded_timesteps[rounded_timesteps == total_timesteps] -= space
else:
rounded_timesteps[rounded_timesteps == total_timesteps] -= space
rounded_timesteps[rounded_timesteps == 0] += space
return rounded_timesteps
@staticmethod
def ldm_transform_latent(z, extra_scale_factor=1):
channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]
if len(z.shape) != 4:
raise ValueError()
z = z * 0.18215
channels = [z[:, i] for i in range(z.shape[1])]
channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)]
return torch.stack(channels, dim=1)
@torch.no_grad()
def __call__(
self,
features: torch.Tensor,
schedule=[1.0, 0.5],
generator=None,
):
features = self.ldm_transform_latent(features)
ts = self.round_timesteps(
torch.arange(0, 1024),
1024,
self.n_distilled_steps,
truncate_start=False,
)
shape = (
features.size(0),
3,
8 * features.size(2),
8 * features.size(3),
)
x_start = torch.zeros(shape, device=features.device, dtype=features.dtype)
schedule_timesteps = [int((1024 - 1) * s) for s in schedule]
for i in schedule_timesteps:
t = ts[i].item()
t_ = torch.tensor([t] * features.shape[0]).to(self.device)
# noise = torch.randn_like(x_start)
noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device)
x_start = (
_extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise
)
c_in = _extract_into_tensor(self.c_in, t_, x_start.shape)
if isinstance(self.ckpt, UNet2DModel):
input = torch.concat([c_in * x_start, F.upsample_nearest(features, scale_factor=8)], dim=1)
model_output = self.ckpt(input, t_).sample
else:
model_output = self.ckpt(c_in * x_start, t_, features=features)
B, C = x_start.shape[:2]
model_output, _ = torch.split(model_output, C, dim=1)
pred_xstart = (
_extract_into_tensor(self.c_out, t_, x_start.shape) * model_output
+ _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start
).clamp(-1, 1)
x_start = pred_xstart
return x_start
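# The default two-step schedule [1.0, 0.5] performs distilled consistency sampling: at each
# selected timestep the current estimate is re-noised, the model predicts x0 directly through
# the c_out / c_skip parameterization, and the clamped prediction becomes the next estimate.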
def save_image(image, name):
import numpy as np
from PIL import Image
image = image[0].cpu().numpy()
image = (image + 1.0) * 127.5
image = image.clip(0, 255).astype(np.uint8)
image = Image.fromarray(image.transpose(1, 2, 0))
image.save(name)
def load_image(uri, size=None, center_crop=False):
import numpy as np
from PIL import Image
image = Image.open(uri)
if center_crop:
image = image.crop(
(
(image.width - min(image.width, image.height)) // 2,
(image.height - min(image.width, image.height)) // 2,
(image.width + min(image.width, image.height)) // 2,
(image.height + min(image.width, image.height)) // 2,
)
)
if size is not None:
image = image.resize(size)
image = torch.tensor(np.array(image).transpose(2, 0, 1)).unsqueeze(0).float()
image = image / 127.5 - 1.0
return image
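# The modules below re-implement the original decoder's architecture layer-for-layer so that its
# safetensors weights can be loaded by name (see rename_state_dict further down) before being
# converted block-by-block into standard diffusers components.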
class TimestepEmbedding_(nn.Module):
def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
super().__init__()
self.emb = nn.Embedding(n_time, n_emb)
self.f_1 = nn.Linear(n_emb, n_out)
self.f_2 = nn.Linear(n_out, n_out)
def forward(self, x) -> torch.Tensor:
x = self.emb(x)
x = self.f_1(x)
x = F.silu(x)
return self.f_2(x)
class ImageEmbedding(nn.Module):
def __init__(self, in_channels=7, out_channels=320) -> None:
super().__init__()
self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x) -> torch.Tensor:
return self.f(x)
class ImageUnembedding(nn.Module):
def __init__(self, in_channels=320, out_channels=6) -> None:
super().__init__()
self.gn = nn.GroupNorm(32, in_channels)
self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x) -> torch.Tensor:
return self.f(F.silu(self.gn(x)))
class ConvResblock(nn.Module):
def __init__(self, in_features=320, out_features=320) -> None:
super().__init__()
self.f_t = nn.Linear(1280, out_features * 2)
self.gn_1 = nn.GroupNorm(32, in_features)
self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
self.gn_2 = nn.GroupNorm(32, out_features)
self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
skip_conv = in_features != out_features
self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) if skip_conv else nn.Identity()
def forward(self, x, t):
x_skip = x
t = self.f_t(F.silu(t))
t = t.chunk(2, dim=1)
t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)
gn_1 = F.silu(self.gn_1(x))
f_1 = self.f_1(gn_1)
gn_2 = self.gn_2(f_1)
return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
# Structurally a ConvResblock, with 2x average-pool downsampling on both the main and skip paths
class Downsample(nn.Module):
def __init__(self, in_channels=320) -> None:
super().__init__()
self.f_t = nn.Linear(1280, in_channels * 2)
self.gn_1 = nn.GroupNorm(32, in_channels)
self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.gn_2 = nn.GroupNorm(32, in_channels)
self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
def forward(self, x, t) -> torch.Tensor:
x_skip = x
t = self.f_t(F.silu(t))
t_1, t_2 = t.chunk(2, dim=1)
t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
t_2 = t_2.unsqueeze(2).unsqueeze(3)
gn_1 = F.silu(self.gn_1(x))
avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
f_1 = self.f_1(avg_pool2d)
gn_2 = self.gn_2(f_1)
f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)
# Structurally a ConvResblock, with 2x nearest-neighbor upsampling on both the main and skip paths
class Upsample(nn.Module):
def __init__(self, in_channels=1024) -> None:
super().__init__()
self.f_t = nn.Linear(1280, in_channels * 2)
self.gn_1 = nn.GroupNorm(32, in_channels)
self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.gn_2 = nn.GroupNorm(32, in_channels)
self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
def forward(self, x, t) -> torch.Tensor:
x_skip = x
t = self.f_t(F.silu(t))
t_1, t_2 = t.chunk(2, dim=1)
t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
t_2 = t_2.unsqueeze(2).unsqueeze(3)
gn_1 = F.silu(self.gn_1(x))
upsample = F.upsample_nearest(gn_1, scale_factor=2)
f_1 = self.f_1(upsample)
gn_2 = self.gn_2(f_1)
f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
return f_2 + F.upsample_nearest(x_skip, scale_factor=2)
class ConvUNetVAE(nn.Module):
def __init__(self) -> None:
super().__init__()
self.embed_image = ImageEmbedding()
self.embed_time = TimestepEmbedding_()
down_0 = nn.ModuleList(
[
ConvResblock(320, 320),
ConvResblock(320, 320),
ConvResblock(320, 320),
Downsample(320),
]
)
down_1 = nn.ModuleList(
[
ConvResblock(320, 640),
ConvResblock(640, 640),
ConvResblock(640, 640),
Downsample(640),
]
)
down_2 = nn.ModuleList(
[
ConvResblock(640, 1024),
ConvResblock(1024, 1024),
ConvResblock(1024, 1024),
Downsample(1024),
]
)
down_3 = nn.ModuleList(
[
ConvResblock(1024, 1024),
ConvResblock(1024, 1024),
ConvResblock(1024, 1024),
]
)
self.down = nn.ModuleList(
[
down_0,
down_1,
down_2,
down_3,
]
)
self.mid = nn.ModuleList(
[
ConvResblock(1024, 1024),
ConvResblock(1024, 1024),
]
)
up_3 = nn.ModuleList(
[
ConvResblock(1024 * 2, 1024),
ConvResblock(1024 * 2, 1024),
ConvResblock(1024 * 2, 1024),
ConvResblock(1024 * 2, 1024),
Upsample(1024),
]
)
up_2 = nn.ModuleList(
[
ConvResblock(1024 * 2, 1024),
ConvResblock(1024 * 2, 1024),
ConvResblock(1024 * 2, 1024),
ConvResblock(1024 + 640, 1024),
Upsample(1024),
]
)
up_1 = nn.ModuleList(
[
ConvResblock(1024 + 640, 640),
ConvResblock(640 * 2, 640),
ConvResblock(640 * 2, 640),
ConvResblock(320 + 640, 640),
Upsample(640),
]
)
up_0 = nn.ModuleList(
[
ConvResblock(320 + 640, 320),
ConvResblock(320 * 2, 320),
ConvResblock(320 * 2, 320),
ConvResblock(320 * 2, 320),
]
)
self.up = nn.ModuleList(
[
up_0,
up_1,
up_2,
up_3,
]
)
self.output = ImageUnembedding()
def forward(self, x, t, features) -> torch.Tensor:
converted = hasattr(self, "converted") and self.converted
x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1)
if converted:
t = self.time_embedding(self.time_proj(t))
else:
t = self.embed_time(t)
x = self.embed_image(x)
skips = [x]
for i, down in enumerate(self.down):
if converted and i in [0, 1, 2, 3]:
x, skips_ = down(x, t)
for skip in skips_:
skips.append(skip)
else:
for block in down:
x = block(x, t)
skips.append(x)
print(x.float().abs().sum())
if converted:
x = self.mid(x, t)
else:
for i in range(2):
x = self.mid[i](x, t)
print(x.float().abs().sum())
for i, up in enumerate(self.up[::-1]):
if converted and i in [0, 1, 2, 3]:
skip_4 = skips.pop()
skip_3 = skips.pop()
skip_2 = skips.pop()
skip_1 = skips.pop()
skips_ = (skip_1, skip_2, skip_3, skip_4)
x = up(x, skips_, t)
else:
for block in up:
if isinstance(block, ConvResblock):
x = torch.concat([x, skips.pop()], dim=1)
x = block(x, t)
return self.output(x)
def rename_state_dict_key(k):
k = k.replace("blocks.", "")
for i in range(5):
k = k.replace(f"down_{i}_", f"down.{i}.")
k = k.replace(f"conv_{i}.", f"{i}.")
k = k.replace(f"up_{i}_", f"up.{i}.")
k = k.replace(f"mid_{i}", f"mid.{i}")
k = k.replace("upsamp.", "4.")
k = k.replace("downsamp.", "3.")
k = k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias")
k = k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias")
k = k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias")
k = k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias")
k = k.replace("f.w", "f.weight").replace("f.b", "f.bias")
k = k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias")
k = k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias")
k = k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias")
return k
def rename_state_dict(sd, embedding):
sd = {rename_state_dict_key(k): v for k, v in sd.items()}
sd["embed_time.emb.weight"] = embedding["weight"]
return sd
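# Script body: encode a test image with the SD v1-5 VAE, decode the latent with the SD VAE decoder
# as a reference, then decode it with the re-implemented ConvUNetVAE loaded from the original
# weights, with the same model after its blocks are swapped for diffusers blocks, with a standalone
# diffusers UNet2DModel, and finally through ConsistencyDecoderVAE, checking agreement at each stage.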
# encode with stable diffusion vae
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.vae.cuda()
# construct original decoder with jitted model
decoder_consistency = ConsistencyDecoder(device="cuda:0")
# construct the re-implemented ConvUNetVAE, load the original weights, and use it in place of the jitted decoder
model = ConvUNetVAE()
model.load_state_dict(
rename_state_dict(
stl("consistency_decoder.safetensors"),
stl("embedding.safetensors"),
)
)
model = model.cuda()
decoder_consistency.ckpt = model
image = load_image(args.test_image, size=(256, 256), center_crop=True)
latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()
# decode with gan
sample_gan = pipe.vae.decode(latent).sample.detach()
save_image(sample_gan, "gan.png")
# decode with conv_unet_vae
sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
save_image(sample_consistency_orig, "con_orig.png")
########### conversion
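# Each down/mid/up block's state dict is re-keyed from the ConvUNetVAE naming
# (gn_1 / f_1 / f_t / gn_2 / f_2 / f_s) to the diffusers resnet naming
# (norm1 / conv1 / time_emb_proj / norm2 / conv2 / conv_shortcut) and loaded into the matching
# ResnetDownsampleBlock2D / UNetMidBlock2D / ResnetUpsampleBlock2D. The pop-then-assert-empty
# pattern guarantees every original weight is consumed exactly once.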
print("CONVERSION")
print("DOWN BLOCK ONE")
block_one_sd_orig = model.down[0].state_dict()
block_one_sd_new = {}
for i in range(3):
block_one_sd_new[f"resnets.{i}.norm1.weight"] = block_one_sd_orig.pop(f"{i}.gn_1.weight")
block_one_sd_new[f"resnets.{i}.norm1.bias"] = block_one_sd_orig.pop(f"{i}.gn_1.bias")
block_one_sd_new[f"resnets.{i}.conv1.weight"] = block_one_sd_orig.pop(f"{i}.f_1.weight")
block_one_sd_new[f"resnets.{i}.conv1.bias"] = block_one_sd_orig.pop(f"{i}.f_1.bias")
block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_one_sd_orig.pop(f"{i}.f_t.weight")
block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_one_sd_orig.pop(f"{i}.f_t.bias")
block_one_sd_new[f"resnets.{i}.norm2.weight"] = block_one_sd_orig.pop(f"{i}.gn_2.weight")
block_one_sd_new[f"resnets.{i}.norm2.bias"] = block_one_sd_orig.pop(f"{i}.gn_2.bias")
block_one_sd_new[f"resnets.{i}.conv2.weight"] = block_one_sd_orig.pop(f"{i}.f_2.weight")
block_one_sd_new[f"resnets.{i}.conv2.bias"] = block_one_sd_orig.pop(f"{i}.f_2.bias")
block_one_sd_new["downsamplers.0.norm1.weight"] = block_one_sd_orig.pop("3.gn_1.weight")
block_one_sd_new["downsamplers.0.norm1.bias"] = block_one_sd_orig.pop("3.gn_1.bias")
block_one_sd_new["downsamplers.0.conv1.weight"] = block_one_sd_orig.pop("3.f_1.weight")
block_one_sd_new["downsamplers.0.conv1.bias"] = block_one_sd_orig.pop("3.f_1.bias")
block_one_sd_new["downsamplers.0.time_emb_proj.weight"] = block_one_sd_orig.pop("3.f_t.weight")
block_one_sd_new["downsamplers.0.time_emb_proj.bias"] = block_one_sd_orig.pop("3.f_t.bias")
block_one_sd_new["downsamplers.0.norm2.weight"] = block_one_sd_orig.pop("3.gn_2.weight")
block_one_sd_new["downsamplers.0.norm2.bias"] = block_one_sd_orig.pop("3.gn_2.bias")
block_one_sd_new["downsamplers.0.conv2.weight"] = block_one_sd_orig.pop("3.f_2.weight")
block_one_sd_new["downsamplers.0.conv2.bias"] = block_one_sd_orig.pop("3.f_2.bias")
assert len(block_one_sd_orig) == 0
block_one = ResnetDownsampleBlock2D(
in_channels=320,
out_channels=320,
temb_channels=1280,
num_layers=3,
add_downsample=True,
resnet_time_scale_shift="scale_shift",
resnet_eps=1e-5,
)
block_one.load_state_dict(block_one_sd_new)
print("DOWN BLOCK TWO")
block_two_sd_orig = model.down[1].state_dict()
block_two_sd_new = {}
for i in range(3):
block_two_sd_new[f"resnets.{i}.norm1.weight"] = block_two_sd_orig.pop(f"{i}.gn_1.weight")
block_two_sd_new[f"resnets.{i}.norm1.bias"] = block_two_sd_orig.pop(f"{i}.gn_1.bias")
block_two_sd_new[f"resnets.{i}.conv1.weight"] = block_two_sd_orig.pop(f"{i}.f_1.weight")
block_two_sd_new[f"resnets.{i}.conv1.bias"] = block_two_sd_orig.pop(f"{i}.f_1.bias")
block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_two_sd_orig.pop(f"{i}.f_t.weight")
block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_two_sd_orig.pop(f"{i}.f_t.bias")
block_two_sd_new[f"resnets.{i}.norm2.weight"] = block_two_sd_orig.pop(f"{i}.gn_2.weight")
block_two_sd_new[f"resnets.{i}.norm2.bias"] = block_two_sd_orig.pop(f"{i}.gn_2.bias")
block_two_sd_new[f"resnets.{i}.conv2.weight"] = block_two_sd_orig.pop(f"{i}.f_2.weight")
block_two_sd_new[f"resnets.{i}.conv2.bias"] = block_two_sd_orig.pop(f"{i}.f_2.bias")
if i == 0:
block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_two_sd_orig.pop(f"{i}.f_s.weight")
block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_two_sd_orig.pop(f"{i}.f_s.bias")
block_two_sd_new["downsamplers.0.norm1.weight"] = block_two_sd_orig.pop("3.gn_1.weight")
block_two_sd_new["downsamplers.0.norm1.bias"] = block_two_sd_orig.pop("3.gn_1.bias")
block_two_sd_new["downsamplers.0.conv1.weight"] = block_two_sd_orig.pop("3.f_1.weight")
block_two_sd_new["downsamplers.0.conv1.bias"] = block_two_sd_orig.pop("3.f_1.bias")
block_two_sd_new["downsamplers.0.time_emb_proj.weight"] = block_two_sd_orig.pop("3.f_t.weight")
block_two_sd_new["downsamplers.0.time_emb_proj.bias"] = block_two_sd_orig.pop("3.f_t.bias")
block_two_sd_new["downsamplers.0.norm2.weight"] = block_two_sd_orig.pop("3.gn_2.weight")
block_two_sd_new["downsamplers.0.norm2.bias"] = block_two_sd_orig.pop("3.gn_2.bias")
block_two_sd_new["downsamplers.0.conv2.weight"] = block_two_sd_orig.pop("3.f_2.weight")
block_two_sd_new["downsamplers.0.conv2.bias"] = block_two_sd_orig.pop("3.f_2.bias")
assert len(block_two_sd_orig) == 0
block_two = ResnetDownsampleBlock2D(
in_channels=320,
out_channels=640,
temb_channels=1280,
num_layers=3,
add_downsample=True,
resnet_time_scale_shift="scale_shift",
resnet_eps=1e-5,
)
block_two.load_state_dict(block_two_sd_new)
print("DOWN BLOCK THREE")
block_three_sd_orig = model.down[2].state_dict()
block_three_sd_new = {}
for i in range(3):
block_three_sd_new[f"resnets.{i}.norm1.weight"] = block_three_sd_orig.pop(f"{i}.gn_1.weight")
block_three_sd_new[f"resnets.{i}.norm1.bias"] = block_three_sd_orig.pop(f"{i}.gn_1.bias")
block_three_sd_new[f"resnets.{i}.conv1.weight"] = block_three_sd_orig.pop(f"{i}.f_1.weight")
block_three_sd_new[f"resnets.{i}.conv1.bias"] = block_three_sd_orig.pop(f"{i}.f_1.bias")
block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_three_sd_orig.pop(f"{i}.f_t.weight")
block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_three_sd_orig.pop(f"{i}.f_t.bias")
block_three_sd_new[f"resnets.{i}.norm2.weight"] = block_three_sd_orig.pop(f"{i}.gn_2.weight")
block_three_sd_new[f"resnets.{i}.norm2.bias"] = block_three_sd_orig.pop(f"{i}.gn_2.bias")
block_three_sd_new[f"resnets.{i}.conv2.weight"] = block_three_sd_orig.pop(f"{i}.f_2.weight")
block_three_sd_new[f"resnets.{i}.conv2.bias"] = block_three_sd_orig.pop(f"{i}.f_2.bias")
if i == 0:
block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_three_sd_orig.pop(f"{i}.f_s.weight")
block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_three_sd_orig.pop(f"{i}.f_s.bias")
block_three_sd_new["downsamplers.0.norm1.weight"] = block_three_sd_orig.pop("3.gn_1.weight")
block_three_sd_new["downsamplers.0.norm1.bias"] = block_three_sd_orig.pop("3.gn_1.bias")
block_three_sd_new["downsamplers.0.conv1.weight"] = block_three_sd_orig.pop("3.f_1.weight")
block_three_sd_new["downsamplers.0.conv1.bias"] = block_three_sd_orig.pop("3.f_1.bias")
block_three_sd_new["downsamplers.0.time_emb_proj.weight"] = block_three_sd_orig.pop("3.f_t.weight")
block_three_sd_new["downsamplers.0.time_emb_proj.bias"] = block_three_sd_orig.pop("3.f_t.bias")
block_three_sd_new["downsamplers.0.norm2.weight"] = block_three_sd_orig.pop("3.gn_2.weight")
block_three_sd_new["downsamplers.0.norm2.bias"] = block_three_sd_orig.pop("3.gn_2.bias")
block_three_sd_new["downsamplers.0.conv2.weight"] = block_three_sd_orig.pop("3.f_2.weight")
block_three_sd_new["downsamplers.0.conv2.bias"] = block_three_sd_orig.pop("3.f_2.bias")
assert len(block_three_sd_orig) == 0
block_three = ResnetDownsampleBlock2D(
in_channels=640,
out_channels=1024,
temb_channels=1280,
num_layers=3,
add_downsample=True,
resnet_time_scale_shift="scale_shift",
resnet_eps=1e-5,
)
block_three.load_state_dict(block_three_sd_new)
print("DOWN BLOCK FOUR")
block_four_sd_orig = model.down[3].state_dict()
block_four_sd_new = {}
for i in range(3):
block_four_sd_new[f"resnets.{i}.norm1.weight"] = block_four_sd_orig.pop(f"{i}.gn_1.weight")
block_four_sd_new[f"resnets.{i}.norm1.bias"] = block_four_sd_orig.pop(f"{i}.gn_1.bias")
block_four_sd_new[f"resnets.{i}.conv1.weight"] = block_four_sd_orig.pop(f"{i}.f_1.weight")
block_four_sd_new[f"resnets.{i}.conv1.bias"] = block_four_sd_orig.pop(f"{i}.f_1.bias")
block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_four_sd_orig.pop(f"{i}.f_t.weight")
block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_four_sd_orig.pop(f"{i}.f_t.bias")
block_four_sd_new[f"resnets.{i}.norm2.weight"] = block_four_sd_orig.pop(f"{i}.gn_2.weight")
block_four_sd_new[f"resnets.{i}.norm2.bias"] = block_four_sd_orig.pop(f"{i}.gn_2.bias")
block_four_sd_new[f"resnets.{i}.conv2.weight"] = block_four_sd_orig.pop(f"{i}.f_2.weight")
block_four_sd_new[f"resnets.{i}.conv2.bias"] = block_four_sd_orig.pop(f"{i}.f_2.bias")
assert len(block_four_sd_orig) == 0
block_four = ResnetDownsampleBlock2D(
in_channels=1024,
out_channels=1024,
temb_channels=1280,
num_layers=3,
add_downsample=False,
resnet_time_scale_shift="scale_shift",
resnet_eps=1e-5,
)
block_four.load_state_dict(block_four_sd_new)
print("MID BLOCK 1")
mid_block_one_sd_orig = model.mid.state_dict()
mid_block_one_sd_new = {}
for i in range(2):
mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight")
mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias")
mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight")
mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias")
mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight")
mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias")
mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight")
mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias")
mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight")
mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias")
assert len(mid_block_one_sd_orig) == 0
mid_block_one = UNetMidBlock2D(
in_channels=1024,
temb_channels=1280,
num_layers=1,
resnet_time_scale_shift="scale_shift",
resnet_eps=1e-5,
add_attention=False,
)
mid_block_one.load_state_dict(mid_block_one_sd_new)
print("UP BLOCK ONE")
up_block_one_sd_orig = model.up[-1].state_dict()
up_block_one_sd_new = {}
for i in range(4):
up_block_one_sd_new[f"resnets.{i}.norm1.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_1.weight")
up_block_one_sd_new[f"resnets.{i}.norm1.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_1.bias")
up_block_one_sd_new[f"resnets.{i}.conv1.weight"] = up_block_one_sd_orig.pop(f"{i}.f_1.weight")
up_block_one_sd_new[f"resnets.{i}.conv1.bias"] = up_block_one_sd_orig.pop(f"{i}.f_1.bias")
up_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_one_sd_orig.pop(f"{i}.f_t.weight")
up_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_one_sd_orig.pop(f"{i}.f_t.bias")
up_block_one_sd_new[f"resnets.{i}.norm2.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_2.weight")
up_block_one_sd_new[f"resnets.{i}.norm2.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_2.bias")
up_block_one_sd_new[f"resnets.{i}.conv2.weight"] = up_block_one_sd_orig.pop(f"{i}.f_2.weight")
up_block_one_sd_new[f"resnets.{i}.conv2.bias"] = up_block_one_sd_orig.pop(f"{i}.f_2.bias")
up_block_one_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_one_sd_orig.pop(f"{i}.f_s.weight")
up_block_one_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_one_sd_orig.pop(f"{i}.f_s.bias")
up_block_one_sd_new["upsamplers.0.norm1.weight"] = up_block_one_sd_orig.pop("4.gn_1.weight")
up_block_one_sd_new["upsamplers.0.norm1.bias"] = up_block_one_sd_orig.pop("4.gn_1.bias")
up_block_one_sd_new["upsamplers.0.conv1.weight"] = up_block_one_sd_orig.pop("4.f_1.weight")
up_block_one_sd_new["upsamplers.0.conv1.bias"] = up_block_one_sd_orig.pop("4.f_1.bias")
up_block_one_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_one_sd_orig.pop("4.f_t.weight")
up_block_one_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_one_sd_orig.pop("4.f_t.bias")
up_block_one_sd_new["upsamplers.0.norm2.weight"] = up_block_one_sd_orig.pop("4.gn_2.weight")
up_block_one_sd_new["upsamplers.0.norm2.bias"] = up_block_one_sd_orig.pop("4.gn_2.bias")
up_block_one_sd_new["upsamplers.0.conv2.weight"] = up_block_one_sd_orig.pop("4.f_2.weight")
up_block_one_sd_new["upsamplers.0.conv2.bias"] = up_block_one_sd_orig.pop("4.f_2.bias")
assert len(up_block_one_sd_orig) == 0
up_block_one = ResnetUpsampleBlock2D(
in_channels=1024,
prev_output_channel=1024,
out_channels=1024,
temb_channels=1280,
num_layers=4,
add_upsample=True,
resnet_time_scale_shift="scale_shift",
resnet_eps=1e-5,
)
up_block_one.load_state_dict(up_block_one_sd_new)
print("UP BLOCK TWO")
up_block_two_sd_orig = model.up[-2].state_dict()
up_block_two_sd_new = {}
for i in range(4):
up_block_two_sd_new[f"resnets.{i}.norm1.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_1.weight")
up_block_two_sd_new[f"resnets.{i}.norm1.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_1.bias")
up_block_two_sd_new[f"resnets.{i}.conv1.weight"] = up_block_two_sd_orig.pop(f"{i}.f_1.weight")
up_block_two_sd_new[f"resnets.{i}.conv1.bias"] = up_block_two_sd_orig.pop(f"{i}.f_1.bias")
up_block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_two_sd_orig.pop(f"{i}.f_t.weight")
up_block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_two_sd_orig.pop(f"{i}.f_t.bias")
up_block_two_sd_new[f"resnets.{i}.norm2.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_2.weight")
up_block_two_sd_new[f"resnets.{i}.norm2.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_2.bias")
up_block_two_sd_new[f"resnets.{i}.conv2.weight"] = up_block_two_sd_orig.pop(f"{i}.f_2.weight")
up_block_two_sd_new[f"resnets.{i}.conv2.bias"] = up_block_two_sd_orig.pop(f"{i}.f_2.bias")
up_block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_two_sd_orig.pop(f"{i}.f_s.weight")
up_block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_two_sd_orig.pop(f"{i}.f_s.bias")
up_block_two_sd_new["upsamplers.0.norm1.weight"] = up_block_two_sd_orig.pop("4.gn_1.weight")
up_block_two_sd_new["upsamplers.0.norm1.bias"] = up_block_two_sd_orig.pop("4.gn_1.bias")
up_block_two_sd_new["upsamplers.0.conv1.weight"] = up_block_two_sd_orig.pop("4.f_1.weight")
up_block_two_sd_new["upsamplers.0.conv1.bias"] = up_block_two_sd_orig.pop("4.f_1.bias")
up_block_two_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_two_sd_orig.pop("4.f_t.weight")
up_block_two_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_two_sd_orig.pop("4.f_t.bias")
up_block_two_sd_new["upsamplers.0.norm2.weight"] = up_block_two_sd_orig.pop("4.gn_2.weight")
up_block_two_sd_new["upsamplers.0.norm2.bias"] = up_block_two_sd_orig.pop("4.gn_2.bias")
up_block_two_sd_new["upsamplers.0.conv2.weight"] = up_block_two_sd_orig.pop("4.f_2.weight")
up_block_two_sd_new["upsamplers.0.conv2.bias"] = up_block_two_sd_orig.pop("4.f_2.bias")
assert len(up_block_two_sd_orig) == 0
up_block_two = ResnetUpsampleBlock2D(
in_channels=640,
prev_output_channel=1024,
out_channels=1024,
temb_channels=1280,
num_layers=4,
add_upsample=True,
resnet_time_scale_shift="scale_shift",
resnet_eps=1e-5,
)
up_block_two.load_state_dict(up_block_two_sd_new)
print("UP BLOCK THREE")
up_block_three_sd_orig = model.up[-3].state_dict()
up_block_three_sd_new = {}
for i in range(4):
up_block_three_sd_new[f"resnets.{i}.norm1.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_1.weight")
up_block_three_sd_new[f"resnets.{i}.norm1.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_1.bias")
up_block_three_sd_new[f"resnets.{i}.conv1.weight"] = up_block_three_sd_orig.pop(f"{i}.f_1.weight")
up_block_three_sd_new[f"resnets.{i}.conv1.bias"] = up_block_three_sd_orig.pop(f"{i}.f_1.bias")
up_block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_three_sd_orig.pop(f"{i}.f_t.weight")
up_block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_three_sd_orig.pop(f"{i}.f_t.bias")
up_block_three_sd_new[f"resnets.{i}.norm2.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_2.weight")
up_block_three_sd_new[f"resnets.{i}.norm2.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_2.bias")
up_block_three_sd_new[f"resnets.{i}.conv2.weight"] = up_block_three_sd_orig.pop(f"{i}.f_2.weight")
up_block_three_sd_new[f"resnets.{i}.conv2.bias"] = up_block_three_sd_orig.pop(f"{i}.f_2.bias")
up_block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_three_sd_orig.pop(f"{i}.f_s.weight")
up_block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_three_sd_orig.pop(f"{i}.f_s.bias")
up_block_three_sd_new["upsamplers.0.norm1.weight"] = up_block_three_sd_orig.pop("4.gn_1.weight")
up_block_three_sd_new["upsamplers.0.norm1.bias"] = up_block_three_sd_orig.pop("4.gn_1.bias")
up_block_three_sd_new["upsamplers.0.conv1.weight"] = up_block_three_sd_orig.pop("4.f_1.weight")
up_block_three_sd_new["upsamplers.0.conv1.bias"] = up_block_three_sd_orig.pop("4.f_1.bias")
up_block_three_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_three_sd_orig.pop("4.f_t.weight")
up_block_three_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_three_sd_orig.pop("4.f_t.bias")
up_block_three_sd_new["upsamplers.0.norm2.weight"] = up_block_three_sd_orig.pop("4.gn_2.weight")
up_block_three_sd_new["upsamplers.0.norm2.bias"] = up_block_three_sd_orig.pop("4.gn_2.bias")
up_block_three_sd_new["upsamplers.0.conv2.weight"] = up_block_three_sd_orig.pop("4.f_2.weight")
up_block_three_sd_new["upsamplers.0.conv2.bias"] = up_block_three_sd_orig.pop("4.f_2.bias")
assert len(up_block_three_sd_orig) == 0
up_block_three = ResnetUpsampleBlock2D(
in_channels=320,
prev_output_channel=1024,
out_channels=640,
temb_channels=1280,
num_layers=4,
add_upsample=True,
resnet_time_scale_shift="scale_shift",
resnet_eps=1e-5,
)
up_block_three.load_state_dict(up_block_three_sd_new)
print("UP BLOCK FOUR")
up_block_four_sd_orig = model.up[-4].state_dict()
up_block_four_sd_new = {}
for i in range(4):
up_block_four_sd_new[f"resnets.{i}.norm1.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_1.weight")
up_block_four_sd_new[f"resnets.{i}.norm1.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_1.bias")
up_block_four_sd_new[f"resnets.{i}.conv1.weight"] = up_block_four_sd_orig.pop(f"{i}.f_1.weight")
up_block_four_sd_new[f"resnets.{i}.conv1.bias"] = up_block_four_sd_orig.pop(f"{i}.f_1.bias")
up_block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_four_sd_orig.pop(f"{i}.f_t.weight")
up_block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_four_sd_orig.pop(f"{i}.f_t.bias")
up_block_four_sd_new[f"resnets.{i}.norm2.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_2.weight")
up_block_four_sd_new[f"resnets.{i}.norm2.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_2.bias")
up_block_four_sd_new[f"resnets.{i}.conv2.weight"] = up_block_four_sd_orig.pop(f"{i}.f_2.weight")
up_block_four_sd_new[f"resnets.{i}.conv2.bias"] = up_block_four_sd_orig.pop(f"{i}.f_2.bias")
up_block_four_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_four_sd_orig.pop(f"{i}.f_s.weight")
up_block_four_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_four_sd_orig.pop(f"{i}.f_s.bias")
assert len(up_block_four_sd_orig) == 0
up_block_four = ResnetUpsampleBlock2D(
in_channels=320,
prev_output_channel=640,
out_channels=320,
temb_channels=1280,
num_layers=4,
add_upsample=False,
resnet_time_scale_shift="scale_shift",
resnet_eps=1e-5,
)
up_block_four.load_state_dict(up_block_four_sd_new)
print("initial projection (conv_in)")
conv_in_sd_orig = model.embed_image.state_dict()
conv_in_sd_new = {}
conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight")
conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias")
assert len(conv_in_sd_orig) == 0
block_out_channels = [320, 640, 1024, 1024]
in_channels = 7
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
conv_in.load_state_dict(conv_in_sd_new)
print("out projection (conv_out) (conv_norm_out)")
out_channels = 6
norm_num_groups = 32
norm_eps = 1e-5
act_fn = "silu"
conv_out_kernel = 3
conv_out_padding = (conv_out_kernel - 1) // 2
conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
# the original applies the activation via torch.nn.functional (F.silu), so no activation module is created here
# conv_act = get_activation(act_fn)
conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)
conv_norm_out.load_state_dict(model.output.gn.state_dict())
conv_out.load_state_dict(model.output.f.state_dict())
print("timestep projection (time_proj) (time_embedding)")
f1_sd = model.embed_time.f_1.state_dict()
f2_sd = model.embed_time.f_2.state_dict()
time_embedding_sd = {
"linear_1.weight": f1_sd.pop("weight"),
"linear_1.bias": f1_sd.pop("bias"),
"linear_2.weight": f2_sd.pop("weight"),
"linear_2.bias": f2_sd.pop("bias"),
}
assert len(f1_sd) == 0
assert len(f2_sd) == 0
time_embedding_type = "learned"
num_train_timesteps = 1024
time_embedding_dim = 1280
time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timestep_input_dim = block_out_channels[0]
time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim)
time_proj.load_state_dict(model.embed_time.emb.state_dict())
time_embedding.load_state_dict(time_embedding_sd)
print("CONVERT")
time_embedding.to("cuda")
time_proj.to("cuda")
conv_in.to("cuda")
block_one.to("cuda")
block_two.to("cuda")
block_three.to("cuda")
block_four.to("cuda")
mid_block_one.to("cuda")
up_block_one.to("cuda")
up_block_two.to("cuda")
up_block_three.to("cuda")
up_block_four.to("cuda")
conv_norm_out.to("cuda")
conv_out.to("cuda")
model.time_proj = time_proj
model.time_embedding = time_embedding
model.embed_image = conv_in
model.down[0] = block_one
model.down[1] = block_two
model.down[2] = block_three
model.down[3] = block_four
model.mid = mid_block_one
model.up[-1] = up_block_one
model.up[-2] = up_block_two
model.up[-3] = up_block_three
model.up[-4] = up_block_four
model.output.gn = conv_norm_out
model.output.f = conv_out
model.converted = True
sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
save_image(sample_consistency_new, "con_new.png")
assert (sample_consistency_orig == sample_consistency_new).all()
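# Gather the converted blocks into a single UNet2DModel with the standard diffusers prefixes
# (conv_in, time_proj, down_blocks.N, mid_block, up_blocks.N, conv_norm_out, conv_out) and re-run
# the decoder with it to confirm the packaged model still matches exactly.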
print("making unet")
unet = UNet2DModel(
in_channels=in_channels,
out_channels=out_channels,
down_block_types=(
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
),
up_block_types=(
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
),
block_out_channels=block_out_channels,
layers_per_block=3,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
resnet_time_scale_shift="scale_shift",
time_embedding_type="learned",
num_train_timesteps=num_train_timesteps,
add_attention=False,
)
unet_state_dict = {}
def add_state_dict(prefix, mod):
for k, v in mod.state_dict().items():
unet_state_dict[f"{prefix}.{k}"] = v
add_state_dict("conv_in", conv_in)
add_state_dict("time_proj", time_proj)
add_state_dict("time_embedding", time_embedding)
add_state_dict("down_blocks.0", block_one)
add_state_dict("down_blocks.1", block_two)
add_state_dict("down_blocks.2", block_three)
add_state_dict("down_blocks.3", block_four)
add_state_dict("mid_block", mid_block_one)
add_state_dict("up_blocks.0", up_block_one)
add_state_dict("up_blocks.1", up_block_two)
add_state_dict("up_blocks.2", up_block_three)
add_state_dict("up_blocks.3", up_block_four)
add_state_dict("conv_norm_out", conv_norm_out)
add_state_dict("conv_out", conv_out)
unet.load_state_dict(unet_state_dict)
print("running with diffusers unet")
unet.to("cuda")
decoder_consistency.ckpt = unet
sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
save_image(sample_consistency_new_2, "con_new_2.png")
assert (sample_consistency_orig == sample_consistency_new_2).all()
print("running with diffusers model")
Encoder.old_constructor = Encoder.__init__
def new_constructor(self, **kwargs):
self.old_constructor(**kwargs)
self.constructor_arguments = kwargs
Encoder.__init__ = new_constructor
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
consistency_vae = ConsistencyDecoderVAE(
encoder_args=vae.encoder.constructor_arguments,
decoder_args=unet.config,
scaling_factor=vae.config.scaling_factor,
block_out_channels=vae.config.block_out_channels,
latent_channels=vae.config.latent_channels,
)
consistency_vae.encoder.load_state_dict(vae.encoder.state_dict())
consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict())
consistency_vae.decoder_unet.load_state_dict(unet.state_dict())
consistency_vae.to(dtype=torch.float16, device="cuda")
sample_consistency_new_3 = consistency_vae.decode(
0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0)
).sample
print("max difference")
print((sample_consistency_orig - sample_consistency_new_3).abs().max())
print("total difference")
print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
# assert (sample_consistency_orig == sample_consistency_new_3).all()
print("running with diffusers pipeline")
pipe = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
)
pipe.to("cuda")
pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png")
if args.save_pretrained is not None:
consistency_vae.save_pretrained(args.save_pretrained)
import argparse
import os
import torch
from diffusers import (
CMStochasticIterativeScheduler,
ConsistencyModelPipeline,
UNet2DModel,
)
TEST_UNET_CONFIG = {
"sample_size": 32,
"in_channels": 3,
"out_channels": 3,
"layers_per_block": 2,
"num_class_embeds": 1000,
"block_out_channels": [32, 64],
"attention_head_dim": 8,
"down_block_types": [
"ResnetDownsampleBlock2D",
"AttnDownBlock2D",
],
"up_block_types": [
"AttnUpBlock2D",
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"attn_norm_num_groups": 32,
"upsample_type": "resnet",
"downsample_type": "resnet",
}
IMAGENET_64_UNET_CONFIG = {
"sample_size": 64,
"in_channels": 3,
"out_channels": 3,
"layers_per_block": 3,
"num_class_embeds": 1000,
"block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4],
"attention_head_dim": 64,
"down_block_types": [
"ResnetDownsampleBlock2D",
"AttnDownBlock2D",
"AttnDownBlock2D",
"AttnDownBlock2D",
],
"up_block_types": [
"AttnUpBlock2D",
"AttnUpBlock2D",
"AttnUpBlock2D",
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"attn_norm_num_groups": 32,
"upsample_type": "resnet",
"downsample_type": "resnet",
}
LSUN_256_UNET_CONFIG = {
"sample_size": 256,
"in_channels": 3,
"out_channels": 3,
"layers_per_block": 2,
"num_class_embeds": None,
"block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
"attention_head_dim": 64,
"down_block_types": [
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"AttnDownBlock2D",
"AttnDownBlock2D",
"AttnDownBlock2D",
],
"up_block_types": [
"AttnUpBlock2D",
"AttnUpBlock2D",
"AttnUpBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "default",
"upsample_type": "resnet",
"downsample_type": "resnet",
}
CD_SCHEDULER_CONFIG = {
"num_train_timesteps": 40,
"sigma_min": 0.002,
"sigma_max": 80.0,
}
CT_IMAGENET_64_SCHEDULER_CONFIG = {
"num_train_timesteps": 201,
"sigma_min": 0.002,
"sigma_max": 80.0,
}
CT_LSUN_256_SCHEDULER_CONFIG = {
"num_train_timesteps": 151,
"sigma_min": 0.002,
"sigma_max": 80.0,
}
def str2bool(v):
"""
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
"""
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("boolean value expected")
def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"]
new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"]
new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"]
new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"]
new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"]
new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"]
new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"]
new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"]
new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"]
new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"]
if has_skip:
new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"]
new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"]
return new_checkpoint
def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)
new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]
new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1)
new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1)
new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1)
new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1)
new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1)
new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1)
new_checkpoint[f"{new_prefix}.to_out.0.weight"] = (
checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1)
)
new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1)
return new_checkpoint
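# Maps an OpenAI consistency-model UNet checkpoint (input_blocks / middle_block / output_blocks
# with in_layers / emb_layers / out_layers and fused qkv attention) onto the diffusers UNet2DModel
# layout; convert_attention above chunks the fused qkv projection into to_q / to_k / to_v,
# squeezing the trailing 1x1 conv dimensions into linear weights.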
def con_pt_to_diffuser(checkpoint_path: str, unet_config):
checkpoint = torch.load(checkpoint_path, map_location="cpu")
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
if unet_config["num_class_embeds"] is not None:
new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]
new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
down_block_types = unet_config["down_block_types"]
layers_per_block = unet_config["layers_per_block"]
attention_head_dim = unet_config["attention_head_dim"]
channels_list = unet_config["block_out_channels"]
current_layer = 1
prev_channels = channels_list[0]
for i, layer_type in enumerate(down_block_types):
current_channels = channels_list[i]
downsample_block_has_skip = current_channels != prev_channels
if layer_type == "ResnetDownsampleBlock2D":
for j in range(layers_per_block):
new_prefix = f"down_blocks.{i}.resnets.{j}"
old_prefix = f"input_blocks.{current_layer}.0"
has_skip = j == 0 and downsample_block_has_skip
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
current_layer += 1
elif layer_type == "AttnDownBlock2D":
for j in range(layers_per_block):
new_prefix = f"down_blocks.{i}.resnets.{j}"
old_prefix = f"input_blocks.{current_layer}.0"
has_skip = j == 0 and downsample_block_has_skip
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
new_prefix = f"down_blocks.{i}.attentions.{j}"
old_prefix = f"input_blocks.{current_layer}.1"
new_checkpoint = convert_attention(
checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
)
current_layer += 1
if i != len(down_block_types) - 1:
new_prefix = f"down_blocks.{i}.downsamplers.0"
old_prefix = f"input_blocks.{current_layer}.0"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
current_layer += 1
prev_channels = current_channels
# hardcoded the mid-block for now
new_prefix = "mid_block.resnets.0"
old_prefix = "middle_block.0"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
new_prefix = "mid_block.attentions.0"
old_prefix = "middle_block.1"
new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
new_prefix = "mid_block.resnets.1"
old_prefix = "middle_block.2"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
current_layer = 0
up_block_types = unet_config["up_block_types"]
for i, layer_type in enumerate(up_block_types):
if layer_type == "ResnetUpsampleBlock2D":
for j in range(layers_per_block + 1):
new_prefix = f"up_blocks.{i}.resnets.{j}"
old_prefix = f"output_blocks.{current_layer}.0"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
current_layer += 1
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
old_prefix = f"output_blocks.{current_layer-1}.1"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
elif layer_type == "AttnUpBlock2D":
for j in range(layers_per_block + 1):
new_prefix = f"up_blocks.{i}.resnets.{j}"
old_prefix = f"output_blocks.{current_layer}.0"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
new_prefix = f"up_blocks.{i}.attentions.{j}"
old_prefix = f"output_blocks.{current_layer}.1"
new_checkpoint = convert_attention(
checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
)
current_layer += 1
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
old_prefix = f"output_blocks.{current_layer-1}.2"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
return new_checkpoint
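# The UNet and scheduler configs are picked from the checkpoint filename (imagenet64, 256
# bedroom/cat, test; cd vs. ct). Illustrative invocation (the script and checkpoint names below
# are assumptions):
#   python convert_consistency_to_diffusers.py --unet_path cd_imagenet64_l2.pt --dump_path ./cd-imagenet64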
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
parser.add_argument(
"--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
)
parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.")
args = parser.parse_args()
args.class_cond = str2bool(args.class_cond)
ckpt_name = os.path.basename(args.unet_path)
print(f"Checkpoint: {ckpt_name}")
# Get U-Net config
if "imagenet64" in ckpt_name:
unet_config = IMAGENET_64_UNET_CONFIG
elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
unet_config = LSUN_256_UNET_CONFIG
elif "test" in ckpt_name:
unet_config = TEST_UNET_CONFIG
else:
raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
if not args.class_cond:
unet_config["num_class_embeds"] = None
converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)
image_unet = UNet2DModel(**unet_config)
image_unet.load_state_dict(converted_unet_ckpt)
# Get scheduler config
if "cd" in ckpt_name or "test" in ckpt_name:
scheduler_config = CD_SCHEDULER_CONFIG
elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
else:
raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)
consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
consistency_model.save_pretrained(args.dump_path)
#!/usr/bin/env python3
import argparse
import math
import os
from copy import deepcopy
import requests
import torch
from audio_diffusion.models import DiffusionAttnUnet1D
from diffusion import sampling
from torch import nn
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
MODELS_MAP = {
"gwf-440k": {
"url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
"sample_rate": 48000,
"sample_size": 65536,
},
"jmann-small-190k": {
"url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
"sample_rate": 48000,
"sample_size": 65536,
},
"jmann-large-580k": {
"url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
"sample_rate": 48000,
"sample_size": 131072,
},
"maestro-uncond-150k": {
"url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
"sample_rate": 16000,
"sample_size": 65536,
},
"unlocked-uncond-250k": {
"url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
"sample_rate": 16000,
"sample_size": 65536,
},
"honk-140k": {
"url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
"sample_rate": 16000,
"sample_size": 65536,
},
}
def alpha_sigma_to_t(alpha, sigma):
"""Returns a timestep, given the scaling factors for the clean image and for
the noise."""
return torch.atan2(sigma, alpha) / math.pi * 2
def get_crash_schedule(t):
sigma = torch.sin(t * math.pi / 2) ** 2
alpha = (1 - sigma**2) ** 0.5
return alpha_sigma_to_t(alpha, sigma)
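# get_crash_schedule maps a linear t in [0, 1] onto the noise schedule used by the original
# audio-diffusion checkpoints: sigma = sin(t * pi / 2) ** 2, alpha = sqrt(1 - sigma ** 2),
# re-expressed as a timestep via atan2(sigma, alpha) * 2 / pi.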
class Object(object):
pass
class DiffusionUncond(nn.Module):
def __init__(self, global_args):
super().__init__()
self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
self.diffusion_ema = deepcopy(self.diffusion)
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
def download(model_name):
url = MODELS_MAP[model_name]["url"]
r = requests.get(url, stream=True)
local_filename = f"./{model_name}.ckpt"
with open(local_filename, "wb") as fp:
for chunk in r.iter_content(chunk_size=8192):
fp.write(chunk)
return local_filename
DOWN_NUM_TO_LAYER = {
"1": "resnets.0",
"2": "attentions.0",
"3": "resnets.1",
"4": "attentions.1",
"5": "resnets.2",
"6": "attentions.2",
}
UP_NUM_TO_LAYER = {
"8": "resnets.0",
"9": "attentions.0",
"10": "resnets.1",
"11": "attentions.1",
"12": "resnets.2",
"13": "attentions.2",
}
MID_NUM_TO_LAYER = {
"1": "resnets.0",
"2": "attentions.0",
"3": "resnets.1",
"4": "attentions.1",
"5": "resnets.2",
"6": "attentions.2",
"8": "resnets.3",
"9": "attentions.3",
"10": "resnets.4",
"11": "attentions.4",
"12": "resnets.5",
"13": "attentions.5",
}
DEPTH_0_TO_LAYER = {
"0": "resnets.0",
"1": "resnets.1",
"2": "resnets.2",
"4": "resnets.0",
"5": "resnets.1",
"6": "resnets.2",
}
RES_CONV_MAP = {
"skip": "conv_skip",
"main.0": "conv_1",
"main.1": "group_norm_1",
"main.3": "conv_2",
"main.4": "group_norm_2",
}
ATTN_MAP = {
"norm": "group_norm",
"qkv_proj": ["query", "key", "value"],
"out_proj": ["proj_attn"],
}
def convert_resconv_naming(name):
if name.startswith("skip"):
return name.replace("skip", RES_CONV_MAP["skip"])
# the name has to be of the form "main.{digit}"
if not name.startswith("main."):
raise ValueError(f"ResConvBlock error with {name}")
return name.replace(name[:6], RES_CONV_MAP[name[:6]])
def convert_attn_naming(name):
for key, value in ATTN_MAP.items():
if name.startswith(key) and not isinstance(value, list):
return name.replace(key, value)
elif name.startswith(key):
return [name.replace(key, v) for v in value]
raise ValueError(f"Attn error with {name}")
def rename(input_string, max_depth=13):
string = input_string
if string.split(".")[0] == "timestep_embed":
return string.replace("timestep_embed", "time_proj")
depth = 0
if string.startswith("net.3."):
depth += 1
string = string[6:]
elif string.startswith("net."):
string = string[4:]
while string.startswith("main.7."):
depth += 1
string = string[7:]
if string.startswith("main."):
string = string[5:]
# mid block
if string[:2].isdigit():
layer_num = string[:2]
string_left = string[2:]
else:
layer_num = string[0]
string_left = string[1:]
if depth == max_depth:
new_layer = MID_NUM_TO_LAYER[layer_num]
prefix = "mid_block"
elif depth > 0 and int(layer_num) < 7:
new_layer = DOWN_NUM_TO_LAYER[layer_num]
prefix = f"down_blocks.{depth}"
elif depth > 0 and int(layer_num) > 7:
new_layer = UP_NUM_TO_LAYER[layer_num]
prefix = f"up_blocks.{max_depth - depth - 1}"
elif depth == 0:
new_layer = DEPTH_0_TO_LAYER[layer_num]
prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"
if not string_left.startswith("."):
raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")
string_left = string_left[1:]
if "resnets" in new_layer:
string_left = convert_resconv_naming(string_left)
elif "attentions" in new_layer:
new_string_left = convert_attn_naming(string_left)
string_left = new_string_left
if not isinstance(string_left, list):
new_string = prefix + "." + new_layer + "." + string_left
else:
new_string = [prefix + "." + new_layer + "." + s for s in string_left]
return new_string
def rename_orig_weights(state_dict):
new_state_dict = {}
for k, v in state_dict.items():
if k.endswith("kernel"):
# up- and downsample layers don't have trainable weights
continue
new_k = rename(k)
# check if we need to transform from Conv => Linear for attention
if isinstance(new_k, list):
new_state_dict = transform_conv_attns(new_state_dict, new_k, v)
else:
new_state_dict[new_k] = v
return new_state_dict
def transform_conv_attns(new_state_dict, new_k, v):
if len(new_k) == 1:
if len(v.shape) == 3:
# weight
new_state_dict[new_k[0]] = v[:, :, 0]
else:
# bias
new_state_dict[new_k[0]] = v
else:
# qkv matrices
tripled_shape = v.shape[0]
single_shape = tripled_shape // 3
for i in range(3):
if len(v.shape) == 3:
new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]
else:
new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]
return new_state_dict
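# main() downloads the requested checkpoint if needed, renames its weights into the diffusers
# UNet1DModel layout, then generates audio with both the original IPLMS sampler and the
# DanceDiffusionPipeline from the same seed and asserts the maximum difference stays below 1e-3.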
def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path):
assert (
model_name == args.model_path
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"]
sample_size = MODELS_MAP[model_name]["sample_size"]
config = Object()
config.sample_size = sample_size
config.sample_rate = sample_rate
config.latent_dim = 0
diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
diffusers_state_dict = diffusers_model.state_dict()
orig_model = DiffusionUncond(config)
orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
orig_model = orig_model.diffusion_ema.eval()
orig_model_state_dict = orig_model.state_dict()
renamed_state_dict = rename_orig_weights(orig_model_state_dict)
renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())
assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items():
assert (
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
if key == "time_proj.weight":
value = value.squeeze()
diffusers_state_dict[key] = value
diffusers_model.load_state_dict(diffusers_state_dict)
steps = 100
seed = 33
diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)
generator = torch.manual_seed(seed)
noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)
t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
step_list = get_crash_schedule(t)
pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)
generator = torch.manual_seed(33)
audio = pipe(num_inference_steps=steps, generator=generator).audios
generated = sampling.iplms_sample(orig_model, noise, step_list, {})
generated = generated.clamp(-1, 1)
diff_sum = (generated - audio).abs().sum()
diff_max = (generated - audio).abs().max()
if args.save:
pipe.save_pretrained(args.checkpoint_path)
print("Diff sum", diff_sum)
print("Diff max", diff_max)
assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"
print(f"Conversion for {model_name} successful!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
parser.add_argument(
"--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
)
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
main(args)
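# Script for converting an original DDPM / VQ autoencoder checkpoint (keys such as
# `temb.dense.*`, `down.{i}.block.*`, `mid.block_1`) into the diffusers
# UNet2DModel / VQModel / AutoencoderKL layout.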
import argparse
import json
import torch
from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("block.", "resnets.")
new_item = new_item.replace("conv_shorcut", "conv1")
new_item = new_item.replace("in_shortcut", "conv_shortcut")
new_item = new_item.replace("temb_proj", "time_emb_proj")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
mapping = []
for old_item in old_list:
new_item = old_item
# In `model.mid`, the layer is called `attn`.
if not in_mid:
new_item = new_item.replace("attn", "attentions")
new_item = new_item.replace(".k.", ".key.")
new_item = new_item.replace(".v.", ".value.")
new_item = new_item.replace(".q.", ".query.")
new_item = new_item.replace("proj_out", "proj_attn")
new_item = new_item.replace("norm", "group_norm")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
if attention_paths_to_split is not None:
if config is None:
raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()
for path in paths:
new_path = path["new"]
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
new_path = new_path.replace("down.", "down_blocks.")
new_path = new_path.replace("up.", "up_blocks.")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
if "attentions" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
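# assign_to_checkpoint copies tensors from the old checkpoint into the new one following the
# old -> new path mapping, applies the down./up. block renames plus any additional replacements,
# and squeezes attention weights (stored as 1x1 convolutions) down to linear weights.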
def convert_ddpm_checkpoint(checkpoint, config):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
down_blocks = {
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
for i in range(num_down_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
if any("downsample" in layer for layer in down_blocks[i]):
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
f"down.{i}.downsample.op.weight"
]
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
if any("block" in layer for layer in down_blocks[i]):
num_blocks = len(
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
)
blocks = {
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
for layer_id in range(num_blocks)
}
if num_blocks > 0:
for j in range(config["layers_per_block"]):
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
if any("attn" in layer for layer in down_blocks[i]):
num_attn = len(
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
)
attns = {
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
for layer_id in range(num_blocks)
}
if num_attn > 0:
for j in range(config["layers_per_block"]):
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
    # Mid block (keys are written with a temporary "mid_new_2" prefix and renamed to "mid_block" at the end)
paths = renew_resnet_paths(mid_block_1_layers)
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
)
paths = renew_resnet_paths(mid_block_2_layers)
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
)
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
if any("upsample" in layer for layer in up_blocks[i]):
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
f"up.{i}.upsample.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
if any("block" in layer for layer in up_blocks[i]):
num_blocks = len(
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
)
blocks = {
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
}
if num_blocks > 0:
for j in range(config["layers_per_block"] + 1):
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
if any("attn" in layer for layer in up_blocks[i]):
num_attn = len(
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
)
attns = {
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
}
if num_attn > 0:
for j in range(config["layers_per_block"] + 1):
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
return new_checkpoint
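# The VQ / KL autoencoder variant below follows the same pattern as convert_ddpm_checkpoint, but
# operates on `encoder.` / `decoder.` prefixed keys (hence shaving three path segments instead of
# two) and additionally copies the quant_conv / post_quant_conv (and optional codebook) weights.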
def convert_vq_autoenc_checkpoint(checkpoint, config):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
new_checkpoint = {}
new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
down_blocks = {
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
for i in range(num_down_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
if any("downsample" in layer for layer in down_blocks[i]):
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
f"encoder.down.{i}.downsample.conv.weight"
]
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
f"encoder.down.{i}.downsample.conv.bias"
]
if any("block" in layer for layer in down_blocks[i]):
num_blocks = len(
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
)
blocks = {
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
for layer_id in range(num_blocks)
}
if num_blocks > 0:
for j in range(config["layers_per_block"]):
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
if any("attn" in layer for layer in down_blocks[i]):
num_attn = len(
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
)
attns = {
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
for layer_id in range(num_blocks)
}
if num_attn > 0:
for j in range(config["layers_per_block"]):
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
    # Mid block (keys are written with a temporary "mid_new_2" prefix and renamed to "mid_block" at the end)
paths = renew_resnet_paths(mid_block_1_layers)
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
)
paths = renew_resnet_paths(mid_block_2_layers)
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
)
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
assign_to_checkpoint(
paths,
new_checkpoint,
checkpoint,
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
if any("upsample" in layer for layer in up_blocks[i]):
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
f"decoder.up.{i}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
f"decoder.up.{i}.upsample.conv.bias"
]
if any("block" in layer for layer in up_blocks[i]):
num_blocks = len(
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
)
blocks = {
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
}
if num_blocks > 0:
for j in range(config["layers_per_block"] + 1):
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
if any("attn" in layer for layer in up_blocks[i]):
num_attn = len(
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
)
attns = {
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
}
if num_attn > 0:
for j in range(config["layers_per_block"] + 1):
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
if "quantize.embedding.weight" in checkpoint:
new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
return new_checkpoint
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the architecture.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
checkpoint = torch.load(args.checkpoint_path)
with open(args.config_file) as f:
config = json.loads(f.read())
    # Decide between the VQ/KL autoencoder case and the UNet (DDPM) case based on the top-level key prefixes
key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()}
if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
else:
converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
if "ddpm" in config:
del config["ddpm"]
if config["_class_name"] == "VQModel":
model = VQModel(**config)
model.load_state_dict(converted_checkpoint)
model.save_pretrained(args.dump_path)
elif config["_class_name"] == "AutoencoderKL":
model = AutoencoderKL(**config)
model.load_state_dict(converted_checkpoint)
model.save_pretrained(args.dump_path)
else:
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
pipe = DDPMPipeline(unet=model, scheduler=scheduler)
pipe.save_pretrained(args.dump_path)
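# Example invocation (file and script names are illustrative, not part of this repository):
#     python convert_ddpm_checkpoint.py --checkpoint_path model.ckpt --config_file config.json --dump_path ./converted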
# Script for converting Hugging Face Diffusers-trained SDXL LoRAs to Kohya format
# This means that you can input your diffusers-trained LoRAs and
# get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
# To get started you can find some cool `diffusers`-trained LoRAs such as this cute Corgi
# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
# and run the script:
# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
# now you can use corgy.safetensors in your WebUI of choice!
# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
# Canonical diffusers training scripts:
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
import argparse
import os
from safetensors.torch import load_file, save_file
from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
def convert_and_save(input_lora, output_lora=None):
if output_lora is None:
base_name = os.path.splitext(input_lora)[0]
output_lora = f"{base_name}_webui.safetensors"
diffusers_state_dict = load_file(input_lora)
peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, output_lora)
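# Conversion chain: diffusers-style LoRA keys -> PEFT keys (convert_all_state_dict_to_peft)
# -> Kohya/WebUI keys (convert_state_dict_to_kohya); both helpers come from diffusers.utils.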
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
parser.add_argument(
"--input_lora",
type=str,
required=True,
help="Path to the input LoRA model file in the diffusers format.",
)
parser.add_argument(
"--output_lora",
type=str,
required=False,
help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
)
args = parser.parse_args()
convert_and_save(args.input_lora, args.output_lora)
# Script for converting a HF Diffusers saved SDXL pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and the two text encoders.
# Does not convert optimizer state or anything else.
import argparse
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
# the following are for sdxl
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
]
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(3):
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i > 0:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(4):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i < 2:
            # no attention layers in up_blocks.2 (the last up block)
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
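# The maps above pair LDM-style prefixes (input_blocks. / output_blocks. / middle_block.) with
# their diffusers counterparts; convert_unet_state_dict below applies them with plain substring
# replacement, so the order in which the maps were built matters.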
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#
vae_conversion_map = [
# (stable-diffusion, HF Diffusers)
("nin_shortcut", "conv_shortcut"),
("norm_out", "conv_norm_out"),
("mid.attn_1.", "mid_block.attentions.0."),
]
for i in range(4):
# down_blocks have two resnets
for j in range(2):
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
sd_down_prefix = f"encoder.down.{i}.block.{j}."
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
if i < 3:
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
sd_downsample_prefix = f"down.{i}.downsample."
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3-i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i+1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
# the following are for SDXL
("q.", "to_q."),
("k.", "to_k."),
("v.", "to_v."),
("proj_out.", "to_out.0."),
]
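# The original VAE stores the mid-block attention projections as 1x1 convolutions, so the HF
# linear weights get two trailing singleton dimensions added back by reshape_weight_for_sd below.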
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
    if w.ndim != 1:
return w.reshape(*w.shape, 1, 1)
else:
return w
def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()}
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
for sd_part, hf_part in vae_conversion_map_attn:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
weights_to_convert = ["q", "k", "v", "proj_out"]
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
# =========================#
# Text Encoder Conversion #
# =========================#
textenc_conversion_lst = [
# (stable-diffusion, HF Diffusers)
("transformer.resblocks.", "text_model.encoder.layers."),
("ln_1", "layer_norm1"),
("ln_2", "layer_norm2"),
(".c_fc.", ".fc1."),
(".c_proj.", ".fc2."),
(".attn", ".self_attn"),
("ln_final.", "text_model.final_layer_norm."),
("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
def convert_openclip_text_enc_state_dict(text_enc_dict):
new_state_dict = {}
capture_qkv_weight = {}
capture_qkv_bias = {}
for k, v in text_enc_dict.items():
if (
k.endswith(".self_attn.q_proj.weight")
or k.endswith(".self_attn.k_proj.weight")
or k.endswith(".self_attn.v_proj.weight")
):
k_pre = k[: -len(".q_proj.weight")]
k_code = k[-len("q_proj.weight")]
if k_pre not in capture_qkv_weight:
capture_qkv_weight[k_pre] = [None, None, None]
capture_qkv_weight[k_pre][code2idx[k_code]] = v
continue
if (
k.endswith(".self_attn.q_proj.bias")
or k.endswith(".self_attn.k_proj.bias")
or k.endswith(".self_attn.v_proj.bias")
):
k_pre = k[: -len(".q_proj.bias")]
k_code = k[-len("q_proj.bias")]
if k_pre not in capture_qkv_bias:
capture_qkv_bias[k_pre] = [None, None, None]
capture_qkv_bias[k_pre][code2idx[k_code]] = v
continue
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
new_state_dict[relabelled_key] = v
for k_pre, tensors in capture_qkv_weight.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
for k_pre, tensors in capture_qkv_bias.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
return new_state_dict
def convert_openai_text_enc_state_dict(text_enc_dict):
return text_enc_dict
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
parser.add_argument(
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
)
args = parser.parse_args()
assert args.model_path is not None, "Must provide a model path!"
assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
# Path for safetensors
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors")
    # Load models from safetensors if they exist, otherwise fall back to the PyTorch .bin files
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
if osp.exists(text_enc_2_path):
text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
else:
text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin")
text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Convert text encoder 1
text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
# Convert text encoder 2
text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
# We call the `.T.contiguous()` to match what's done in
# https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
"conditioner.embedders.1.model.text_projection.weight"
).T.contiguous()
# Put together new checkpoint
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
if args.half:
state_dict = {k: v.half() for k, v in state_dict.items()}
if args.use_safetensors:
save_file(state_dict, args.checkpoint_path)
else:
state_dict = {"state_dict": state_dict}
torch.save(state_dict, args.checkpoint_path)
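# Example invocation (the script name is illustrative):
#     python convert_diffusers_to_original_sdxl.py --model_path ./sdxl-pipeline --checkpoint_path sdxl.safetensors --use_safetensors --half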
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or anything else.
import argparse
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
]
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#
vae_conversion_map = [
# (stable-diffusion, HF Diffusers)
("nin_shortcut", "conv_shortcut"),
("norm_out", "conv_norm_out"),
("mid.attn_1.", "mid_block.attentions.0."),
]
for i in range(4):
# down_blocks have two resnets
for j in range(2):
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
sd_down_prefix = f"encoder.down.{i}.block.{j}."
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
if i < 3:
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
sd_downsample_prefix = f"down.{i}.downsample."
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3-i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i+1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "query."),
("k.", "key."),
("v.", "value."),
("proj_out.", "proj_attn."),
]
# This is probably not the ideal solution, but it works.
vae_extra_conversion_map = [
("to_q", "q"),
("to_k", "k"),
("to_v", "v"),
("to_out.0", "proj_out"),
]
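# convert_vae_state_dict uses the extra map above to rename mid-block attention weights that
# already use the newer to_q/to_k/to_v naming back to the old q/k/v conv names (reshaping them
# on the way) so that Stable Diffusion tooling expecting the old layout can load them.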
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
    if w.ndim != 1:
return w.reshape(*w.shape, 1, 1)
else:
return w
def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()}
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
for sd_part, hf_part in vae_conversion_map_attn:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
weights_to_convert = ["q", "k", "v", "proj_out"]
keys_to_rename = {}
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
for weight_name, real_weight_name in vae_extra_conversion_map:
if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
keys_to_rename[k] = k.replace(weight_name, real_weight_name)
for k, v in keys_to_rename.items():
if k in new_state_dict:
print(f"Renaming {k} to {v}")
new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
del new_state_dict[k]
return new_state_dict
# =========================#
# Text Encoder Conversion #
# =========================#
textenc_conversion_lst = [
# (stable-diffusion, HF Diffusers)
("resblocks.", "text_model.encoder.layers."),
("ln_1", "layer_norm1"),
("ln_2", "layer_norm2"),
(".c_fc.", ".fc1."),
(".c_proj.", ".fc2."),
(".attn", ".self_attn"),
("ln_final.", "transformer.text_model.final_layer_norm."),
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
def convert_text_enc_state_dict_v20(text_enc_dict):
new_state_dict = {}
capture_qkv_weight = {}
capture_qkv_bias = {}
for k, v in text_enc_dict.items():
if (
k.endswith(".self_attn.q_proj.weight")
or k.endswith(".self_attn.k_proj.weight")
or k.endswith(".self_attn.v_proj.weight")
):
k_pre = k[: -len(".q_proj.weight")]
k_code = k[-len("q_proj.weight")]
if k_pre not in capture_qkv_weight:
capture_qkv_weight[k_pre] = [None, None, None]
capture_qkv_weight[k_pre][code2idx[k_code]] = v
continue
if (
k.endswith(".self_attn.q_proj.bias")
or k.endswith(".self_attn.k_proj.bias")
or k.endswith(".self_attn.v_proj.bias")
):
k_pre = k[: -len(".q_proj.bias")]
k_code = k[-len("q_proj.bias")]
if k_pre not in capture_qkv_bias:
capture_qkv_bias[k_pre] = [None, None, None]
capture_qkv_bias[k_pre][code2idx[k_code]] = v
continue
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
new_state_dict[relabelled_key] = v
for k_pre, tensors in capture_qkv_weight.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
for k_pre, tensors in capture_qkv_bias.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
return new_state_dict
def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
parser.add_argument(
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
)
args = parser.parse_args()
assert args.model_path is not None, "Must provide a model path!"
assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
# Path for safetensors
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
    # Load models from safetensors if they exist, otherwise fall back to the PyTorch .bin files
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
        # Add the "transformer." prefix first so the conversion patterns can strip it where needed (e.g. the final layer norm and embeddings)
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
if args.half:
state_dict = {k: v.half() for k, v in state_dict.items()}
if args.use_safetensors:
save_file(state_dict, args.checkpoint_path)
else:
state_dict = {"state_dict": state_dict}
torch.save(state_dict, args.checkpoint_path)
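# Script for converting the original facebookresearch/DiT checkpoints (DiT-XL/2) into a
# diffusers DiTPipeline (Transformer2DModel + AutoencoderKL + DDIMScheduler).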
import argparse
import os
import torch
from torchvision.datasets.utils import download_url
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
def download_model(model_name):
"""
Downloads a pre-trained DiT model from the web.
"""
local_path = f"pretrained_models/{model_name}"
if not os.path.isfile(local_path):
os.makedirs("pretrained_models", exist_ok=True)
web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
download_url(web_path, "pretrained_models")
model = torch.load(local_path, map_location=lambda storage, loc: storage)
return model
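# main() below remaps the original DiT keys: the shared timestep/class embedder weights are
# copied into every transformer block's adaptive layer norm (norm1.emb.*), the fused qkv
# projection is chunked into separate to_q/to_k/to_v tensors, and the final adaLN / linear
# layers become proj_out_1 / proj_out_2.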
def main(args):
state_dict = download_model(pretrained_models[args.image_size])
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
state_dict.pop("x_embedder.proj.weight")
state_dict.pop("x_embedder.proj.bias")
for depth in range(28):
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
"t_embedder.mlp.0.weight"
]
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
"t_embedder.mlp.0.bias"
]
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
"t_embedder.mlp.2.weight"
]
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
"t_embedder.mlp.2.bias"
]
state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
"y_embedder.embedding_table.weight"
]
state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
f"blocks.{depth}.adaLN_modulation.1.weight"
]
state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
f"blocks.{depth}.adaLN_modulation.1.bias"
]
q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)
state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
f"blocks.{depth}.attn.proj.weight"
]
state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]
state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]
state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
state_dict.pop(f"blocks.{depth}.attn.proj.weight")
state_dict.pop(f"blocks.{depth}.attn.proj.bias")
state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
state_dict.pop("t_embedder.mlp.0.weight")
state_dict.pop("t_embedder.mlp.0.bias")
state_dict.pop("t_embedder.mlp.2.weight")
state_dict.pop("t_embedder.mlp.2.bias")
state_dict.pop("y_embedder.embedding_table.weight")
state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
state_dict.pop("final_layer.linear.weight")
state_dict.pop("final_layer.linear.bias")
state_dict.pop("final_layer.adaLN_modulation.1.weight")
state_dict.pop("final_layer.adaLN_modulation.1.bias")
# DiT XL/2
transformer = Transformer2DModel(
sample_size=args.image_size // 8,
num_layers=28,
attention_head_dim=72,
in_channels=4,
out_channels=8,
patch_size=2,
attention_bias=True,
num_attention_heads=16,
activation_fn="gelu-approximate",
num_embeds_ada_norm=1000,
norm_type="ada_norm_zero",
norm_elementwise_affine=False,
)
transformer.load_state_dict(state_dict, strict=True)
scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_schedule="linear",
prediction_type="epsilon",
clip_sample=False,
)
vae = AutoencoderKL.from_pretrained(args.vae_model)
pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
if args.save:
pipeline.save_pretrained(args.checkpoint_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--image_size",
default=256,
type=int,
required=False,
help="Image size of pretrained model, either 256 or 512.",
)
parser.add_argument(
"--vae_model",
default="stabilityai/sd-vae-ft-ema",
type=str,
required=False,
help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
)
parser.add_argument(
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
)
args = parser.parse_args()
main(args)
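# Example invocation (the script name is illustrative):
#     python convert_dit_to_diffusers.py --image_size 256 --checkpoint_path ./dit-xl-2-256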