Commit 727428ec authored by jerrrrry

Initial commit CI/CD
task_flag="dit_g2_full_1024p" # the task flag is used to identify folders.
resume_module_root=./ckpts/t2i/model/pytorch_model_distill.pt # checkpoint root for model resume
resume_ema_root=./ckpts/t2i/model/pytorch_model_ema.pt # checkpoint root for ema resume
index_file=dataset/porcelain/jsons/porcelain.json # index file for dataloader
results_dir=./log_EXP # save root for results
batch_size=1 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=1 # gradient accumulation
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=9999999 # save a checkpoint every this many steps.
ckpt_latest_every=9999999 # save a checkpoint named `latest.pt` every this many steps.
ckpt_every_n_epoch=2 # save a checkpoint every this many epochs.
epochs=8 # total training epochs
sh $(dirname "$0")/run_g.sh \
--task-flag ${task_flag} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.018 \
--predict-type v_prediction \
--uncond-p 0 \
--uncond-p-t5 0 \
--index-file ${index_file} \
--random-flip \
--lr ${lr} \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-flash-attn \
--use-fp16 \
--extra-fp16 \
--results-dir ${results_dir} \
--resume \
--resume-module-root ${resume_module_root} \
--resume-ema-root ${resume_ema_root} \
--epochs ${epochs} \
--ckpt-every ${ckpt_every} \
--ckpt-latest-every ${ckpt_latest_every} \
--ckpt-every-n-epoch ${ckpt_every_n_epoch} \
--log-every 10 \
--deepspeed \
--use-zero-stage 2 \
--gradient-checkpointing \
--cpu-offloading \
"$@"
task_flag="canny_controlnet" # the task flag is used to identify folders.
control_type=canny
resume_module_root=./ckpts/t2i/model/pytorch_model_distill.pt # checkpoint root for resume
index_file=/path/to/your/indexfile # index file for dataloader
results_dir=./log_EXP # save root for results
batch_size=1 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=2 # gradient accumulation
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=10000 # save a checkpoint every this many steps.
ckpt_latest_every=5000 # save a checkpoint named `latest.pt` every this many steps.
epochs=100 # total training epochs
sh $(dirname "$0")/run_g_controlnet.sh \
--task-flag ${task_flag} \
--control-type ${control_type} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.018 \
--predict-type v_prediction \
--uncond-p 0.44 \
--uncond-p-t5 0.44 \
--index-file ${index_file} \
--random-flip \
--lr ${lr} \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-flash-attn \
--use-fp16 \
--results-dir ${results_dir} \
--resume \
--resume-module-root ${resume_module_root} \
--epochs ${epochs} \
--ckpt-every ${ckpt_every} \
--ckpt-latest-every ${ckpt_latest_every} \
--log-every 10 \
--deepspeed \
--deepspeed-optimizer \
--use-zero-stage 2 \
--gradient-checkpointing \
"$@"
import gc
import json
import os
import random
import sys
import time
from functools import partial
from glob import glob
import deepspeed
import numpy as np
import torch
import torch.distributed as dist
from diffusers.models import AutoencoderKL
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer, logging as tf_logging
from IndexKits.index_kits import ResolutionGroup
from IndexKits.index_kits.sampler import (
DistributedSamplerWithStartIndex,
BlockDistributedSampler,
)
from hydit.config import get_args
from hydit.constants import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER
from hydit.data_loader.arrow_load_stream import TextImageArrowStream
from hydit.diffusion import create_diffusion
from hydit.ds_config import deepspeed_config_from_args
from hydit.lr_scheduler import WarmupLR
from hydit.modules.ema import EMA
from hydit.modules.fp16_layers import Float16Module
from hydit.modules.models import HUNYUAN_DIT_MODELS, HunYuanDiT
from hydit.modules.text_encoder import MT5Embedder
from hydit.modules.posemb_layers import init_image_posemb
from hydit.utils.tools import create_exp_folder, model_resume, get_trainable_params
import bitsandbytes as bnb
def deepspeed_initialize(args, logger, model, opt, deepspeed_config):
logger.info(f"Initialize deepspeed...")
logger.info(f" Using deepspeed optimizer")
def get_learning_rate_scheduler(warmup_min_lr, lr, warmup_num_steps, opt):
return WarmupLR(opt, warmup_min_lr, lr, warmup_num_steps)
logger.info(
f" Building scheduler with warmup_min_lr={args.warmup_min_lr}, warmup_num_steps={args.warmup_num_steps}"
)
model, opt, _, scheduler = deepspeed.initialize(
model=model,
optimizer=opt,
model_parameters=get_trainable_params(model),
config_params=deepspeed_config,
args=args,
lr_scheduler=(
partial(
get_learning_rate_scheduler,
args.warmup_min_lr,
args.lr,
args.warmup_num_steps,
)
if args.warmup_num_steps > 0
else None
),
)
return model, opt, scheduler
def save_checkpoint(
args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir, by="step"
):
def save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"):
cur_ckpt_save_dir = f"{checkpoint_dir}/{tag}"
if rank == 0:
if args.use_fp16:
target_module = getattr(model.module, 'module', model.module)
target_module.save_pretrained(cur_ckpt_save_dir)
else:
model.module.save_pretrained(cur_ckpt_save_dir)
def save_model_weight(client_state, tag):
checkpoint_path = f"{checkpoint_dir}/{tag}"
try:
if args.training_parts == "lora":
save_lora_weight(checkpoint_dir, client_state, tag=tag)
else:
model.save_checkpoint(
checkpoint_dir, client_state=client_state, tag=tag
)
logger.info(f"Saved checkpoint to {checkpoint_path}")
except Exception as e:
logger.error(f"Failed to save checkpoint to {checkpoint_path}. {type(e)}: {e}")
return False, ""
return True, checkpoint_path
client_state = {"steps": train_steps, "epoch": epoch, "args": args}
if ema is not None:
client_state["ema"] = ema.state_dict()
# Save model weights by epoch or step
dst_paths = []
if by == "epoch":
tag = f"e{epoch:04d}.pt"
dst_paths.append(save_model_weight(client_state, tag))
elif by == "step":
if train_steps % args.ckpt_every == 0:
tag = f"{train_steps:07d}.pt"
dst_paths.append(save_model_weight(client_state, tag))
if (
train_steps % args.ckpt_latest_every == 0
or train_steps == args.max_training_steps
):
tag = "latest.pt"
dst_paths.append(save_model_weight(client_state, tag))
elif by == "final":
tag = "final.pt"
dst_paths.append(save_model_weight(client_state, tag))
else:
raise ValueError(f"Unknown save checkpoint method: {by}")
saved = any([state for state, _ in dst_paths])
if not saved:
return False
# Maybe clear optimizer states
if not args.save_optimizer_state:
dist.barrier()
if rank == 0 and len(dst_paths) > 0:
# Delete optimizer states to avoid occupying too much disk space.
for dst_path in dst_paths:
for opt_state_path in glob(f"{dst_path}/zero_*_optim_states.pt"):
os.remove(opt_state_path)
return True
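# Checkpoint layout sketch (illustrative; each tag becomes a sub-directory under
# checkpoint_dir when DeepSpeed's save_checkpoint is used, or a save_pretrained
# output directory for LoRA):
#   checkpoint_dir/0001000.pt/  <- by="step", every args.ckpt_every steps
#   checkpoint_dir/latest.pt/   <- by="step", every args.ckpt_latest_every steps
#   checkpoint_dir/e0002.pt/    <- by="epoch", every args.ckpt_every_n_epoch epochs
#   checkpoint_dir/final.pt/    <- by="final", once at the end of training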
@torch.no_grad()
def prepare_model_inputs(
args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img
):
(
image,
text_embedding,
text_embedding_mask,
text_embedding_t5,
text_embedding_mask_t5,
kwargs,
) = batch
# clip & mT5 text embedding
text_embedding = text_embedding.to(device)
text_embedding_mask = text_embedding_mask.to(device)
encoder_hidden_states = text_encoder(
text_embedding.to(device),
attention_mask=text_embedding_mask.to(device),
)[0]
text_embedding_t5 = text_embedding_t5.to(device).squeeze(1)
text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1)
with torch.no_grad():
output_t5 = text_encoder_t5(
input_ids=text_embedding_t5,
attention_mask=(
text_embedding_mask_t5 if T5_ENCODER["attention_mask"] else None
),
output_hidden_states=True,
)
encoder_hidden_states_t5 = output_t5["hidden_states"][
T5_ENCODER["layer_index"]
].detach()
# additional condition
if args.size_cond:
image_meta_size = kwargs["image_meta_size"].to(device)
else:
image_meta_size = None
if args.use_style_cond:
style = kwargs["style"].to(device)
else:
style = None
if args.extra_fp16:
image = image.half()
# Map input images to latent space + normalize latents:
image = image.to(device)
vae_scaling_factor = vae.config.scaling_factor
latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor)
# positional embedding
_, _, height, width = image.shape
reso = f"{height}x{width}"
cos_cis_img, sin_cis_img = freqs_cis_img[reso]
# Model conditions
model_kwargs = dict(
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=cos_cis_img,
sin_cis_img=sin_cis_img,
)
return latents, model_kwargs
def main(args):
if args.training_parts == "lora":
args.use_ema = False
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
dist.init_process_group("nccl")
world_size = dist.get_world_size()
batch_size = args.batch_size
grad_accu_steps = args.grad_accu_steps
global_batch_size = world_size * batch_size * grad_accu_steps
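# e.g. 8 GPUs x batch_size 1 x grad_accu_steps 1 gives a global batch of 8 samples
# per optimizer update (illustrative numbers; the real values come from the CLI flags).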
rank = dist.get_rank()
device = rank % torch.cuda.device_count()
seed = args.global_seed * world_size + rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.set_device(device)
print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")
deepspeed_config = deepspeed_config_from_args(args, global_batch_size)
# Setup an experiment folder
experiment_dir, checkpoint_dir, logger = create_exp_folder(args, rank)
# Log all the arguments
logger.info(sys.argv)
logger.info(str(args))
# Save to a json file
args_dict = vars(args)
args_dict["world_size"] = world_size
with open(f"{experiment_dir}/args.json", "w") as f:
json.dump(args_dict, f, indent=4)
# Disable the message "Some weights of the model checkpoint at ... were not used when initializing BertModel."
# If you need those warnings, comment out the following line.
tf_logging.set_verbosity_error()
# ===========================================================================
# Building HYDIT
# ===========================================================================
logger.info("Building HYDIT Model.")
# ---------------------------------------------------------------------------
# Training sample base size, such as 256/512/1024. Note that this is only a
# base size, not necessarily the actual size of the training samples. The
# actual sample sizes are determined by `resolutions` when multi-resolution
# training is enabled.
# ---------------------------------------------------------------------------
image_size = args.image_size
if len(image_size) == 1:
image_size = [image_size[0], image_size[0]]
if len(image_size) != 2:
raise ValueError(f"Invalid image size: {args.image_size}")
assert image_size[0] % 8 == 0 and image_size[1] % 8 == 0, (
"Image size must be divisible by 8 (for the VAE encoder). " f"got {image_size}"
)
latent_size = [image_size[0] // 8, image_size[1] // 8]
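# e.g. the default image_size of 1024 yields latent_size = [128, 128], since the
# frozen VAE below downsamples images by a factor of 8.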
# initialize model by deepspeed
# assert args.deepspeed, f"Must enable deepspeed in this script: train_deepspeed.py"
if args.deepspeed:
with deepspeed.zero.Init(
data_parallel_group=torch.distributed.group.WORLD,
remote_device=None if args.remote_device == "none" else args.remote_device,
config_dict_or_path=deepspeed_config,
mpu=None,
enabled=args.zero_stage == 3,
):
model = HUNYUAN_DIT_MODELS[args.model](
args,
input_size=latent_size,
log_fn=logger.info,
)
else:
model = HUNYUAN_DIT_MODELS[args.model](
args,
input_size=latent_size,
log_fn=logger.info,
)
model.to(device)
# Multi-resolution / Single-resolution training.
if args.multireso:
resolutions = ResolutionGroup(
image_size[0],
align=16,
step=args.reso_step,
target_ratios=args.target_ratios,
).data
else:
resolutions = ResolutionGroup(
image_size[0], align=16, target_ratios=["1:1"]
).data
freqs_cis_img = init_image_posemb(
args.rope_img,
resolutions=resolutions,
patch_size=model.patch_size,
hidden_size=model.hidden_size,
num_heads=model.num_heads,
log_fn=logger.info,
rope_real=args.rope_real,
)
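# freqs_cis_img maps a "HxW" resolution string to its precomputed RoPE tables,
# which prepare_model_inputs looks up per batch, e.g. (illustrative resolution):
#   cos_cis_img, sin_cis_img = freqs_cis_img["1024x1024"]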
# Create EMA model and convert to fp16 if needed.
ema = None
if args.use_ema:
ema = EMA(args, model, device, logger)
# Setup gradient checkpointing
if args.gradient_checkpointing:
model.enable_gradient_checkpointing()
# Setup FP16 main model:
if args.use_fp16:
model = Float16Module(model, args)
logger.info(
f" Using main model with data type {'fp16' if args.use_fp16 else 'fp32'}"
)
diffusion = create_diffusion(
noise_schedule=args.noise_schedule,
predict_type=args.predict_type,
learn_sigma=args.learn_sigma,
mse_loss_weight_type=args.mse_loss_weight_type,
beta_start=args.beta_start,
beta_end=args.beta_end,
noise_offset=args.noise_offset,
)
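# Background note (standard definitions, not specific to this repo): with
# predict_type="v_prediction" the training target is v_t = alpha_t * eps - sigma_t * x_0
# rather than the raw noise eps, and the "scaled_linear" schedule spaces sqrt(beta_t)
# linearly between sqrt(beta_start) and sqrt(beta_end) before squaring.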
# Setup VAE
logger.info(f" Loading vae from {VAE_EMA_PATH}")
vae = AutoencoderKL.from_pretrained(VAE_EMA_PATH)
# Setup BERT text encoder
logger.info(f" Loading Bert text encoder from {TEXT_ENCODER}")
text_encoder = BertModel.from_pretrained(TEXT_ENCODER, False, revision=None)
# Setup BERT tokenizer:
logger.info(f" Loading Bert tokenizer from {TOKENIZER}")
tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
# Setup T5 text encoder
mt5_path = T5_ENCODER["MT5"]
embedder_t5 = MT5Embedder(
mt5_path, torch_dtype=T5_ENCODER["torch_dtype"], max_length=args.text_len_t5
)
tokenizer_t5 = embedder_t5.tokenizer
text_encoder_t5 = embedder_t5.model
if args.extra_fp16:
logger.info(f" Using fp16 for extra modules: vae, text_encoder")
vae = vae.half().to(device)
text_encoder = text_encoder.half().to(device)
text_encoder_t5 = text_encoder_t5.half().to(device)
else:
vae = vae.to(device)
text_encoder = text_encoder.to(device)
text_encoder_t5 = text_encoder_t5.to(device)
logger.info(
f" Optimizer parameters: lr={args.lr}, weight_decay={args.weight_decay}"
)
logger.info(" Using deepspeed optimizer")
opt = None
# ===========================================================================
# Building Dataset
# ===========================================================================
logger.info(f"Building Streaming Dataset.")
logger.info(f" Loading index file {args.index_file} (v2)")
dataset = TextImageArrowStream(
args=args,
resolution=image_size[0],
random_flip=args.random_flip,
log_fn=logger.info,
index_file=args.index_file,
multireso=args.multireso,
batch_size=batch_size,
world_size=world_size,
random_shrink_size_cond=args.random_shrink_size_cond,
merge_src_cond=args.merge_src_cond,
uncond_p=args.uncond_p,
text_ctx_len=args.text_len,
tokenizer=tokenizer,
uncond_p_t5=args.uncond_p_t5,
text_ctx_len_t5=args.text_len_t5,
tokenizer_t5=tokenizer_t5,
)
if args.multireso:
sampler = BlockDistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
seed=args.global_seed,
shuffle=False,
drop_last=True,
batch_size=batch_size,
)
else:
sampler = DistributedSamplerWithStartIndex(
dataset,
num_replicas=world_size,
rank=rank,
seed=args.global_seed,
shuffle=False,
drop_last=True,
)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
sampler=sampler,
num_workers=args.num_workers,
pin_memory=True,
drop_last=True,
)
logger.info(f" Dataset contains {len(dataset):,} images.")
logger.info(f" Index file: {args.index_file}.")
if args.multireso:
logger.info(
f" Using MultiResolutionBucketIndexV2 with step {dataset.index_manager.step} "
f"and base size {dataset.index_manager.base_size}"
)
logger.info(f"\n {dataset.index_manager.resolutions}")
# ===========================================================================
# Loading parameter
# ===========================================================================
logger.info("Loading parameters")
start_epoch = 0
start_epoch_step = 0
train_steps = 0
# Resume checkpoint if needed
if args.resume:
model, ema, start_epoch, start_epoch_step, train_steps = model_resume(
args, model, ema, logger, len(loader)
)
if args.training_parts == "lora":
loraconfig = LoraConfig(
r=args.rank, lora_alpha=args.rank, target_modules=args.target_modules
)
if args.use_fp16:
model.module.enable_input_requires_grad()
model.module = get_peft_model(model.module, loraconfig)
else:
model.enable_input_requires_grad()
model = get_peft_model(model, loraconfig)
logger.info(f" Training parts: {args.training_parts}")
if args.deepspeed:
model, opt, scheduler = deepspeed_initialize(
args, logger, model, opt, deepspeed_config
)
# ===========================================================================
# Training
# ===========================================================================
model.train()
if args.use_ema:
ema.eval()
print(f" Worker {rank} ready.")
dist.barrier()
iters_per_epoch = len(loader)
logger.info(
" ****************************** Running training ******************************"
)
logger.info(f" Number GPUs: {world_size}")
logger.info(f" Number training samples: {len(dataset):,}")
logger.info(
f" Number parameters: {sum(p.numel() for p in model.parameters()):,}"
)
logger.info(
f" Number trainable params: {sum(p.numel() for p in get_trainable_params(model)):,}"
)
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Iters per epoch: {iters_per_epoch:,}")
logger.info(f" Batch size per device: {batch_size}")
logger.info(
f" Batch size all device: {batch_size * world_size * grad_accu_steps:,} (world_size * batch_size * grad_accu_steps)"
)
logger.info(f" Gradient Accu steps: {args.grad_accu_steps}")
logger.info(
f" Total optimization steps: {args.epochs * iters_per_epoch // grad_accu_steps:,}"
)
logger.info(f" Training epochs: {start_epoch}/{args.epochs}")
logger.info(
f" Training epoch steps: {start_epoch_step:,}/{iters_per_epoch:,}"
)
logger.info(
f" Training total steps: {train_steps:,}/{min(args.max_training_steps, args.epochs * iters_per_epoch):,}"
)
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Noise schedule: {args.noise_schedule}")
logger.info(
f" Beta limits: ({args.beta_start}, {args.beta_end})"
)
logger.info(f" Learn sigma: {args.learn_sigma}")
logger.info(f" Prediction type: {args.predict_type}")
logger.info(f" Noise offset: {args.noise_offset}")
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Using EMA model: {args.use_ema} ({args.ema_dtype})")
if args.use_ema:
logger.info(
f" Using EMA decay: {ema.max_value if args.use_ema else None}"
)
logger.info(
f" Using EMA warmup power: {ema.power if args.use_ema else None}"
)
logger.info(f" Using main model fp16: {args.use_fp16}")
logger.info(f" Using extra modules fp16: {args.extra_fp16}")
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Experiment directory: {experiment_dir}")
logger.info(
" *******************************************************************************"
)
if args.gc_interval > 0:
gc.disable()
gc.collect()
# Variables for monitoring/logging purposes:
log_steps = 0
running_loss = 0
start_time = time.time()
# Training loop
epoch = start_epoch
while epoch < args.epochs:
# Random shuffle dataset
shuffle_seed = args.global_seed + epoch
logger.info(f" Start random shuffle with seed={shuffle_seed}")
# Make sure all processes use the same seed to shuffle the dataset.
dataset.shuffle(seed=shuffle_seed, fast=True)
logger.info(f" End of random shuffle")
# Move sampler to start_index
if not args.multireso:
start_index = start_epoch_step * world_size * batch_size
if start_index != sampler.start_index:
sampler.start_index = start_index
# Reset start_epoch_step to zero, to ensure next epoch will start from the beginning.
start_epoch_step = 0
logger.info(f" Iters left this epoch: {len(loader):,}")
logger.info(f" Beginning epoch {epoch}...")
for batch in loader:
latents, model_kwargs = prepare_model_inputs(
args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img
)
loss_dict = diffusion.training_losses(
model=model, x_start=latents, model_kwargs=model_kwargs
)
loss = loss_dict["loss"].mean()
if args.deepspeed:
model.backward(loss)
else:
loss.backward()
last_batch_iteration = (train_steps + 1) // (
global_batch_size // (batch_size * world_size)
)
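# global_batch_size // (batch_size * world_size) equals grad_accu_steps, so
# last_batch_iteration counts completed optimizer updates rather than micro-batches;
# it is forwarded to the WarmupLR scheduler below.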
if args.deepspeed:
model.step(lr_kwargs={"last_batch_iteration": last_batch_iteration})
else:
opt.step()
scheduler.step(last_batch_iteration=last_batch_iteration)
if args.use_ema:
if args.use_fp16:
target_module = getattr(model.module, 'module', model.module)
ema.update(target_module, step=train_steps)
else:
ema.update(model.module, step=train_steps)
# ===========================================================================
# Log loss values:
# ===========================================================================
running_loss += loss.item()
log_steps += 1
train_steps += 1
if train_steps % args.log_every == 0:
# Measure training speed:
torch.cuda.synchronize()
end_time = time.time()
steps_per_sec = log_steps / (end_time - start_time)
# Reduce loss history over all processes:
avg_loss = torch.tensor(running_loss / log_steps, device=device)
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
avg_loss = avg_loss.item() / world_size
# get lr from deepspeed fused optimizer
logger.info(
f"(step={train_steps:07d}) "
+ (
f"(update_step={train_steps // args.grad_accu_steps:07d}) "
if args.grad_accu_steps > 1
else ""
)
+ f"Train Loss: {avg_loss:.4f}, "
f"Lr: {opt.param_groups[0]['lr']:.6g}, "
f"Steps/Sec: {steps_per_sec:.2f}, "
f"Samples/Sec: {steps_per_sec * batch_size * world_size:.2f}"
)
# Reset monitoring variables:
running_loss = 0
log_steps = 0
start_time = time.time()
# collect gc:
if args.gc_interval > 0 and (train_steps % args.gc_interval == 0):
gc.collect()
if (
train_steps % args.ckpt_every == 0
or train_steps % args.ckpt_latest_every == 0
) and train_steps > 0:
save_checkpoint(
args,
rank,
logger,
model,
ema,
epoch,
train_steps,
checkpoint_dir,
by="step",
)
if train_steps >= args.max_training_steps:
logger.info(f"Breaking step loop at train_steps={train_steps}.")
break
if train_steps >= args.max_training_steps:
logger.info(f"Breaking epoch loop at epoch={epoch}.")
break
# Finish an epoch
if args.ckpt_every_n_epoch > 0 and epoch % args.ckpt_every_n_epoch == 0:
save_checkpoint(
args,
rank,
logger,
model,
ema,
epoch,
train_steps,
checkpoint_dir,
by="epoch",
)
epoch += 1
save_checkpoint(
args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir, by="final"
)
dist.destroy_process_group()
if __name__ == "__main__":
# Start
main(get_args())
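# Launch sketch (hedged; the real launch command lives in run_g.sh, which is not
# shown in this excerpt). Assuming the stock torchrun launcher on one 8-GPU node:
#   torchrun --nproc_per_node=8 train_deepspeed.py --deepspeed --use-zero-stage 2 ...
# with the remaining flags assembled by the shell wrapper above.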
import gc
import json
import os
import random
import sys
import time
from functools import partial
from glob import glob
from pathlib import Path
import numpy as np
import deepspeed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.distributed.optim import ZeroRedundancyOptimizer
from torchvision.transforms import functional as TF
from diffusers.models import AutoencoderKL
from transformers import BertModel, BertTokenizer, logging as tf_logging
from hydit.config import get_args
from hydit.constants import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER
from hydit.lr_scheduler import WarmupLR
from hydit.data_loader.arrow_load_stream import TextImageArrowStream
from hydit.diffusion import create_diffusion
from hydit.ds_config import deepspeed_config_from_args
from hydit.modules.ema import EMA
from hydit.modules.fp16_layers import Float16Module
from hydit.modules.models import HUNYUAN_DIT_MODELS
from hydit.modules.controlnet import HunYuanControlNet
from hydit.modules.posemb_layers import init_image_posemb
from hydit.utils.tools import (
create_logger,
set_seeds,
create_exp_folder,
model_resume,
get_trainable_params,
)
from IndexKits.index_kits import ResolutionGroup
from IndexKits.index_kits.sampler import (
DistributedSamplerWithStartIndex,
BlockDistributedSampler,
)
from peft import LoraConfig, get_peft_model
from hydit.annotator.dwpose import DWposeDetector
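# Compatibility shim (assumed intent): older torch releases expose only the private
# _LRScheduler base class, so aliasing it as LRScheduler keeps code that type-checks
# against torch.optim.lr_scheduler.LRScheduler working.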
torch.optim.lr_scheduler.LRScheduler = torch.optim.lr_scheduler._LRScheduler
from transformers import pipeline
import cv2
from PIL import Image
depth_estimator = pipeline(
"depth-estimation", device="cuda:{}".format(int(os.getenv("LOCAL_RANK", "0")))
)
pose_detector = DWposeDetector()
def deepspeed_initialize(args, logger, model, opt, deepspeed_config):
logger.info(f"Initialize deepspeed...")
logger.info(f" Using deepspeed optimizer")
def get_learning_rate_scheduler(warmup_min_lr, lr, warmup_num_steps, opt):
return WarmupLR(opt, warmup_min_lr, lr, warmup_num_steps)
logger.info(
f" Building scheduler with warmup_min_lr={args.warmup_min_lr}, warmup_num_steps={args.warmup_num_steps}"
)
model, opt, _, scheduler = deepspeed.initialize(
model=model,
model_parameters=get_trainable_params(model),
config_params=deepspeed_config,
args=args,
lr_scheduler=(
partial(
get_learning_rate_scheduler,
args.warmup_min_lr,
args.lr,
args.warmup_num_steps,
)
if args.warmup_num_steps > 0
else None
),
)
return model, opt, scheduler
def save_checkpoint(args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir):
def save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"):
cur_ckpt_save_dir = f"{checkpoint_dir}/{tag}"
if rank == 0:
if args.use_fp16:
model.module.module.save_pretrained(cur_ckpt_save_dir)
else:
model.module.save_pretrained(cur_ckpt_save_dir)
checkpoint_path = "[Not rank 0. Disabled output.]"
client_state = {"steps": train_steps, "epoch": epoch, "args": args}
if ema is not None:
client_state["ema"] = ema.state_dict()
dst_paths = []
if train_steps % args.ckpt_every == 0:
checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
try:
if args.training_parts == "lora":
save_lora_weight(
checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"
)
else:
model.save_checkpoint(
checkpoint_dir,
client_state=client_state,
tag=f"{train_steps:07d}.pt",
)
dst_paths.append(checkpoint_path)
logger.info(f"Saved checkpoint to {checkpoint_path}")
except Exception as e:
logger.error(f"Failed to save checkpoint to {checkpoint_path}. {type(e)}: {e}")
if (
train_steps % args.ckpt_latest_every == 0
or train_steps == args.max_training_steps
):
save_name = "latest.pt"
checkpoint_path = f"{checkpoint_dir}/{save_name}"
try:
if args.training_parts == "lora":
save_lora_weight(checkpoint_dir, client_state, tag=f"{save_name}")
else:
model.save_checkpoint(
checkpoint_dir, client_state=client_state, tag=f"{save_name}"
)
dst_paths.append(checkpoint_path)
logger.info(f"Saved checkpoint to {checkpoint_path}")
except Exception as e:
logger.error(f"Failed to save checkpoint to {checkpoint_path}. {type(e)}: {e}")
dist.barrier()
if rank == 0 and len(dst_paths) > 0:
# Delete optimizer states to avoid occupying too much disk space.
for dst_path in dst_paths:
for opt_state_path in glob(f"{dst_path}/*_00_optim_states.pt"):
os.remove(opt_state_path)
return checkpoint_path
def get_canny(np_img, low_threshold=100, high_threshold=200):
# tensor = deNormalize()
# image = tensor_to_img(tensor)
image = cv2.Canny(np_img, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
# canny_tensor = img_to_norm_tensor(canny_img)
return image
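# Usage sketch: takes an (H, W, 3) uint8 RGB array and returns an (H, W, 3) uint8
# edge map, with the single-channel Canny output replicated across three channels
# so it can be normalized and VAE-encoded like a regular image.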
def get_depth(np_img):
pil_img = Image.fromarray(np_img)
depth = depth_estimator(pil_img)["depth"]
depth = np.array(depth)
depth = depth[:, :, None]
depth = np.concatenate([depth, depth, depth], axis=2)
return depth
def get_pose(np_img):
return pose_detector(np_img)[0]
@torch.no_grad()
def prepare_model_inputs(
args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img
):
(
image,
text_embedding,
text_embedding_mask,
text_embedding_t5,
text_embedding_mask_t5,
kwargs,
) = batch
# clip & mT5 text embedding
text_embedding = text_embedding.to(device)
text_embedding_mask = text_embedding_mask.to(device)
encoder_hidden_states = text_encoder(
text_embedding.to(device),
attention_mask=text_embedding_mask.to(device),
)[0]
text_embedding_t5 = text_embedding_t5.to(device).squeeze(1)
text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1)
with torch.no_grad():
output_t5 = text_encoder_t5(
input_ids=text_embedding_t5,
attention_mask=(
text_embedding_mask_t5 if T5_ENCODER["attention_mask"] else None
),
output_hidden_states=True,
)
encoder_hidden_states_t5 = output_t5["hidden_states"][
T5_ENCODER["layer_index"]
].detach()
# additional condition
if args.size_cond:
image_meta_size = kwargs["image_meta_size"].to(device)
else:
image_meta_size = None
if args.use_style_cond:
style = kwargs["style"].to(device)
else:
style = None
np_img = (
image.squeeze(0)
.add(1)
.mul(255 / 2)
.permute(1, 2, 0)
.cpu()
.numpy()
.astype("uint8")
)
if args.control_type == "canny":
condition = get_canny(np_img)
elif args.control_type == "depth":
condition = get_depth(np_img)
elif args.control_type == "pose":
condition = get_pose(np_img)
else:
raise NotImplementedError
condition = Image.fromarray(condition)
condition = TF.to_tensor(condition)
condition = TF.normalize(condition, [0.5], [0.5])
condition = condition.unsqueeze(0).to(device)
if args.extra_fp16:
image = image.half()
# Map input images to latent space + normalize latents:
image = image.to(device)
vae_scaling_factor = vae.config.scaling_factor
latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor)
condition = vae.encode(condition).latent_dist.sample().mul_(vae_scaling_factor)
# positional embedding
_, _, height, width = image.shape
reso = f"{height}x{width}"
cos_cis_img, sin_cis_img = freqs_cis_img[reso]
# Model conditions
model_kwargs = dict(
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=cos_cis_img,
sin_cis_img=sin_cis_img,
condition=condition,
)
return latents, model_kwargs
def main(args):
args.use_ema = False # EMA usage is discouraged during ControlNet training.
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
dist.init_process_group("nccl")
world_size = dist.get_world_size()
batch_size = args.batch_size
grad_accu_steps = args.grad_accu_steps
global_batch_size = world_size * batch_size * grad_accu_steps
rank = dist.get_rank()
device = rank % torch.cuda.device_count()
seed = args.global_seed * world_size + rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.set_device(device)
print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")
deepspeed_config = deepspeed_config_from_args(args, global_batch_size)
# Setup an experiment folder
experiment_dir, checkpoint_dir, logger = create_exp_folder(args, rank)
# Log all the arguments
logger.info(sys.argv)
logger.info(str(args))
# Save to a json file
args_dict = vars(args)
args_dict["world_size"] = world_size
with open(f"{experiment_dir}/args.json", "w") as f:
json.dump(args_dict, f, indent=4)
# Disable the message "Some weights of the model checkpoint at ... were not used when initializing BertModel."
# If you need those warnings, comment out the following line.
tf_logging.set_verbosity_error()
# ===========================================================================
# Building HYDIT
# ===========================================================================
logger.info("Building HYDIT Model.")
# ---------------------------------------------------------------------------
# Training sample base size, such as 256/512/1024. Note that this is only a
# base size, not necessarily the actual size of the training samples. The
# actual sample sizes are determined by `resolutions` when multi-resolution
# training is enabled.
# ---------------------------------------------------------------------------
image_size = args.image_size
if len(image_size) == 1:
image_size = [image_size[0], image_size[0]]
if len(image_size) != 2:
raise ValueError(f"Invalid image size: {args.image_size}")
assert image_size[0] % 8 == 0 and image_size[1] % 8 == 0, (
"Image size must be divisible by 8 (for the VAE encoder). " f"got {image_size}"
)
latent_size = [image_size[0] // 8, image_size[1] // 8]
# initialize model by deepspeed
assert args.deepspeed, f"Must enable deepspeed in this script: train_deepspeed.py"
with deepspeed.zero.Init(
data_parallel_group=torch.distributed.group.WORLD,
remote_device=None if args.remote_device == "none" else args.remote_device,
config_dict_or_path=deepspeed_config,
mpu=None,
enabled=args.zero_stage == 3,
):
model = HUNYUAN_DIT_MODELS[args.model](
args,
input_size=latent_size,
log_fn=logger.info,
)
controlnet = HunYuanControlNet(
args,
input_size=latent_size,
depth=40,
hidden_size=1408,
patch_size=2,
num_heads=16,
mlp_ratio=4.3637,
log_fn=logger.info,
)
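# These hyper-parameters (depth=40, hidden_size=1408, num_heads=16, mlp_ratio~4.3637)
# appear to mirror the DiT-g/2 backbone above, so the ControlNet branch can be
# initialised from the frozen DiT weights via from_dit() further below.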
# Multi-resolution / Single-resolution training.
if args.multireso:
resolutions = ResolutionGroup(
image_size[0],
align=16,
step=args.reso_step,
target_ratios=args.target_ratios,
).data
else:
resolutions = ResolutionGroup(
image_size[0], align=16, target_ratios=["1:1"]
).data
freqs_cis_img = init_image_posemb(
args.rope_img,
resolutions=resolutions,
patch_size=model.patch_size,
hidden_size=model.hidden_size,
num_heads=model.num_heads,
log_fn=logger.info,
rope_real=args.rope_real,
)
# Create EMA model and convert to fp16 if needed.
ema = None
if args.use_ema:
ema = EMA(args, model, device, logger)
# Setup gradient checkpointing
if args.gradient_checkpointing:
model.enable_gradient_checkpointing()
controlnet.enable_gradient_checkpointing()
# Setup FP16 main model:
if args.use_fp16:
model = Float16Module(model, args)
controlnet = Float16Module(controlnet, args)
logger.info(
f" Using main model with data type {'fp16' if args.use_fp16 else 'fp32'}"
)
diffusion = create_diffusion(
noise_schedule=args.noise_schedule,
predict_type=args.predict_type,
learn_sigma=args.learn_sigma,
mse_loss_weight_type=args.mse_loss_weight_type,
beta_start=args.beta_start,
beta_end=args.beta_end,
noise_offset=args.noise_offset,
)
# Setup VAE
logger.info(f" Loading vae from {VAE_EMA_PATH}")
vae = AutoencoderKL.from_pretrained(VAE_EMA_PATH)
# Setup BERT text encoder
logger.info(f" Loading Bert text encoder from {TEXT_ENCODER}")
text_encoder = BertModel.from_pretrained(TEXT_ENCODER, False, revision=None)
# Setup BERT tokenizer:
logger.info(f" Loading Bert tokenizer from {TOKENIZER}")
tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
# Setup T5 text encoder
from hydit.modules.text_encoder import MT5Embedder
mt5_path = T5_ENCODER["MT5"]
embedder_t5 = MT5Embedder(
mt5_path, torch_dtype=T5_ENCODER["torch_dtype"], max_length=args.text_len_t5
)
tokenizer_t5 = embedder_t5.tokenizer
text_encoder_t5 = embedder_t5.model
if args.extra_fp16:
logger.info(f" Using fp16 for extra modules: vae, text_encoder")
vae = vae.half().to(device)
text_encoder = text_encoder.half().to(device)
text_encoder_t5 = text_encoder_t5.half().to(device)
else:
vae = vae.to(device)
text_encoder = text_encoder.to(device)
text_encoder_t5 = text_encoder_t5.to(device)
logger.info(
f" Optimizer parameters: lr={args.lr}, weight_decay={args.weight_decay}"
)
logger.info(" Using deepspeed optimizer")
opt = None
# ===========================================================================
# Building Dataset
# ===========================================================================
logger.info(f"Building Streaming Dataset.")
logger.info(f" Loading index file {args.index_file} (v2)")
dataset = TextImageArrowStream(
args=args,
resolution=image_size[0],
random_flip=args.random_flip,
log_fn=logger.info,
index_file=args.index_file,
multireso=args.multireso,
batch_size=batch_size,
world_size=world_size,
random_shrink_size_cond=args.random_shrink_size_cond,
merge_src_cond=args.merge_src_cond,
uncond_p=args.uncond_p,
text_ctx_len=args.text_len,
tokenizer=tokenizer,
uncond_p_t5=args.uncond_p_t5,
text_ctx_len_t5=args.text_len_t5,
tokenizer_t5=tokenizer_t5,
)
if args.multireso:
sampler = BlockDistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
seed=args.global_seed,
shuffle=False,
drop_last=True,
batch_size=batch_size,
)
else:
sampler = DistributedSamplerWithStartIndex(
dataset,
num_replicas=world_size,
rank=rank,
seed=args.global_seed,
shuffle=False,
drop_last=True,
)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
sampler=sampler,
num_workers=args.num_workers,
pin_memory=True,
drop_last=True,
)
logger.info(f" Dataset contains {len(dataset):,} images.")
logger.info(f" Index file: {args.index_file}.")
if args.multireso:
logger.info(
f" Using MultiResolutionBucketIndexV2 with step {dataset.index_manager.step} "
f"and base size {dataset.index_manager.base_size}"
)
logger.info(f"\n {dataset.index_manager.resolutions}")
# ===========================================================================
# Loading parameter
# ===========================================================================
logger.info("Loading parameters")
start_epoch = 0
start_epoch_step = 0
train_steps = 0
# Resume checkpoint if needed
if args.resume:
model, ema, start_epoch, start_epoch_step, train_steps = model_resume(
args, model, ema, logger, len(loader)
)
if args.training_parts == "lora":
loraconfig = LoraConfig(
r=args.rank, lora_alpha=args.rank, target_modules=args.target_modules
)
if args.use_fp16:
model.module = get_peft_model(model.module, loraconfig)
else:
model = get_peft_model(model, loraconfig)
logger.info(f" Training parts: {args.training_parts}")
if args.use_fp16:
controlnet.module.from_dit(model.module)
controlnet.module.set_trainable()
else:
controlnet.from_dit(model)
controlnet.set_trainable()
logger.info(f" ControlNet loaded from DiT")
controlnet, opt, scheduler = deepspeed_initialize(
args, logger, controlnet, opt, deepspeed_config
)
# ===========================================================================
# Training
# ===========================================================================
model.eval()
model.requires_grad_(False)
model = model.to(device)
if args.use_ema:
ema.eval()
print(f" Worker {rank} ready.")
dist.barrier()
iters_per_epoch = len(loader)
logger.info(
" ****************************** Running training ******************************"
)
logger.info(f" Number GPUs: {world_size}")
logger.info(f" Number training samples: {len(dataset):,}")
logger.info(
f" Number parameters: {sum(p.numel() for p in controlnet.parameters()):,}"
)
logger.info(
f" Number trainable params: {sum(p.numel() for p in get_trainable_params(controlnet)):,}"
)
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Iters per epoch: {iters_per_epoch:,}")
logger.info(f" Batch size per device: {batch_size}")
logger.info(
f" Batch size all device: {batch_size * world_size * grad_accu_steps:,} (world_size * batch_size * grad_accu_steps)"
)
logger.info(f" Gradient Accu steps: {args.grad_accu_steps}")
logger.info(
f" Total optimization steps: {args.epochs * iters_per_epoch // grad_accu_steps:,}"
)
logger.info(f" Training epochs: {start_epoch}/{args.epochs}")
logger.info(
f" Training epoch steps: {start_epoch_step:,}/{iters_per_epoch:,}"
)
logger.info(
f" Training total steps: {train_steps:,}/{min(args.max_training_steps, args.epochs * iters_per_epoch):,}"
)
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Noise schedule: {args.noise_schedule}")
logger.info(
f" Beta limits: ({args.beta_start}, {args.beta_end})"
)
logger.info(f" Learn sigma: {args.learn_sigma}")
logger.info(f" Prediction type: {args.predict_type}")
logger.info(f" Noise offset: {args.noise_offset}")
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Using EMA model: {args.use_ema} ({args.ema_dtype})")
if args.use_ema:
logger.info(
f" Using EMA decay: {ema.max_value if args.use_ema else None}"
)
logger.info(
f" Using EMA warmup power: {ema.power if args.use_ema else None}"
)
logger.info(f" Using main model fp16: {args.use_fp16}")
logger.info(f" Using extra modules fp16: {args.extra_fp16}")
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Experiment directory: {experiment_dir}")
logger.info(
" *******************************************************************************"
)
if args.gc_interval > 0:
gc.disable()
gc.collect()
# Variables for monitoring/logging purposes:
log_steps = 0
running_loss = 0
start_time = time.time()
# Training loop
for epoch in range(start_epoch, args.epochs):
shuffle_seed = args.global_seed + epoch
logger.info(f" Start random shuffle with seed={shuffle_seed}")
# Make sure all processes use the same seed to shuffle the dataset.
dataset.shuffle(seed=shuffle_seed, fast=True)
logger.info(f" End of random shuffle")
# Move sampler to start_index
if not args.multireso:
start_index = start_epoch_step * world_size * batch_size
if start_index != sampler.start_index:
sampler.start_index = start_index
# Reset start_epoch_step to zero, to ensure next epoch will start from the beginning.
start_epoch_step = 0
logger.info(f" Iters left this epoch: {len(loader):,}")
logger.info(f" Beginning epoch {epoch}...")
step = 0
for batch in loader:
step += 1
latents, model_kwargs = prepare_model_inputs(
args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img
)
loss_dict = diffusion.training_losses(
model=model,
x_start=latents,
model_kwargs=model_kwargs,
controlnet=controlnet,
)
loss = loss_dict["loss"].mean()
controlnet.backward(loss)
last_batch_iteration = (train_steps + 1) // (
global_batch_size // (batch_size * world_size)
)
controlnet.step(lr_kwargs={"last_batch_iteration": last_batch_iteration})
if args.use_ema:
if args.use_fp16:
ema.update(model.module.module, step=step)
else:
ema.update(model.module, step=step)
# ===========================================================================
# Log loss values:
# ===========================================================================
running_loss += loss.item()
log_steps += 1
train_steps += 1
if train_steps % args.log_every == 0:
# Measure training speed:
torch.cuda.synchronize()
end_time = time.time()
steps_per_sec = log_steps / (end_time - start_time)
# Reduce loss history over all processes:
avg_loss = torch.tensor(running_loss / log_steps, device=device)
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
avg_loss = avg_loss.item() / world_size
# get lr from deepspeed fused optimizer
logger.info(
f"(step={train_steps:07d}) "
+ (
f"(update_step={train_steps // args.grad_accu_steps:07d}) "
if args.grad_accu_steps > 1
else ""
)
+ f"Train Loss: {avg_loss:.4f}, "
f"Lr: {opt.param_groups[0]['lr']:.6g}, "
f"Steps/Sec: {steps_per_sec:.2f}, "
f"Samples/Sec: {int(steps_per_sec * batch_size * world_size):d}"
)
# Reset monitoring variables:
running_loss = 0
log_steps = 0
start_time = time.time()
# collect gc:
if args.gc_interval > 0 and (step % args.gc_interval == 0):
gc.collect()
if (
train_steps % args.ckpt_every == 0
or train_steps % args.ckpt_latest_every == 0
# or train_steps == args.max_training_steps
) and train_steps > 0:
save_checkpoint(
args,
rank,
logger,
controlnet,
ema,
epoch,
train_steps,
checkpoint_dir,
)
if train_steps >= args.max_training_steps:
logger.info(f"Breaking step loop at train_steps={train_steps}.")
break
if train_steps >= args.max_training_steps:
logger.info(f"Breaking epoch loop at epoch={epoch}.")
break
dist.destroy_process_group()
if __name__ == "__main__":
# Start
main(get_args())
import gc
import json
import os
import random
import sys
import time
from functools import partial
from glob import glob
import deepspeed
import numpy as np
import torch
import torch.distributed as dist
from diffusers.models import AutoencoderKL
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer, logging as tf_logging
from IndexKits.index_kits import ResolutionGroup
from IndexKits.index_kits.sampler import (
DistributedSamplerWithStartIndex,
BlockDistributedSampler,
)
from hydit.config import get_args
from hydit.constants import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER
from hydit.data_loader.arrow_load_stream import TextImageArrowStream
from hydit.diffusion import create_diffusion
from hydit.ds_config import deepspeed_config_from_args
from hydit.lr_scheduler import WarmupLR
from hydit.modules.ema import EMA
from hydit.modules.fp16_layers import Float16Module
from hydit.modules.models import HUNYUAN_DIT_MODELS, HunYuanDiT
from hydit.modules.text_encoder import MT5Embedder
from hydit.modules.posemb_layers import init_image_posemb
from hydit.utils.tools import (
create_exp_folder,
model_resume,
get_trainable_params,
get_trainable_params_ipa,
)
from hydit.utils.img_clip_emb import ImgClipEmbDetector
def deepspeed_initialize(args, logger, model, opt, deepspeed_config):
logger.info(f"Initialize deepspeed...")
logger.info(f" Using deepspeed optimizer")
def get_learning_rate_scheduler(warmup_min_lr, lr, warmup_num_steps, opt):
return WarmupLR(opt, warmup_min_lr, lr, warmup_num_steps)
logger.info(
f" Building scheduler with warmup_min_lr={args.warmup_min_lr}, warmup_num_steps={args.warmup_num_steps}"
)
model, opt, _, scheduler = deepspeed.initialize(
model=model,
model_parameters=get_trainable_params_ipa(model, args, freeze_others=True),
config_params=deepspeed_config,
args=args,
lr_scheduler=(
partial(
get_learning_rate_scheduler,
args.warmup_min_lr,
args.lr,
args.warmup_num_steps,
)
if args.warmup_num_steps > 0
else None
),
)
return model, opt, scheduler
def save_checkpoint(
args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir, by="step"
):
def save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"):
cur_ckpt_save_dir = f"{checkpoint_dir}/{tag}"
if rank == 0:
if args.use_fp16:
model.module.module.save_pretrained(cur_ckpt_save_dir)
else:
model.module.save_pretrained(cur_ckpt_save_dir)
def save_ipadapter_weight(
checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"
):
cur_ckpt_save_dir = f"{checkpoint_dir}/{tag}"
save_state = {}
if rank == 0:
if args.use_fp16:
for param_tensor in model.module.module.state_dict():
if "ip_adapter" in param_tensor:
save_state.update(
{
param_tensor: model.module.module.state_dict()[
param_tensor
]
}
)
torch.save(save_state, cur_ckpt_save_dir)
else:
for param_tensor in model.module.state_dict():
if "ip_adapter" in param_tensor:
save_state.update(
{param_tensor: model.module.state_dict()[param_tensor]}
)
torch.save(save_state, cur_ckpt_save_dir)
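# Only tensors whose state_dict key contains "ip_adapter" are written, so the file
# holds just the adapter weights (a hypothetical key would look like
# "blocks.0.attn2.ip_adapter_kv_proj.weight"), not the full DiT checkpoint.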
def save_model_weight(client_state, tag):
checkpoint_path = f"{checkpoint_dir}/{tag}"
try:
if args.training_parts == "lora":
save_lora_weight(checkpoint_dir, client_state, tag=tag)
elif args.training_parts == "ipadapter":
save_ipadapter_weight(checkpoint_dir, client_state, tag=tag)
else:
model.save_checkpoint(
checkpoint_dir, client_state=client_state, tag=tag
)
logger.info(f"Saved checkpoint to {checkpoint_path}")
except Exception as e:
logger.error(f"Failed to save checkpoint to {checkpoint_path}. {type(e)}: {e}")
return False, ""
return True, checkpoint_path
client_state = {"steps": train_steps, "epoch": epoch, "args": args}
if ema is not None:
client_state["ema"] = ema.state_dict()
# Save model weights by epoch or step
dst_paths = []
if by == "epoch":
tag = f"e{epoch:04d}.pt"
dst_paths.append(save_model_weight(client_state, tag))
elif by == "step":
if train_steps % args.ckpt_every == 0:
tag = f"{train_steps:07d}.pt"
dst_paths.append(save_model_weight(client_state, tag))
if (
train_steps % args.ckpt_latest_every == 0
or train_steps == args.max_training_steps
):
tag = "latest.pt"
dst_paths.append(save_model_weight(client_state, tag))
elif by == "final":
tag = "final.pt"
dst_paths.append(save_model_weight(client_state, tag))
else:
raise ValueError(f"Unknown save checkpoint method: {by}")
saved = any([state for state, _ in dst_paths])
if not saved:
return False
# Maybe clear optimizer states
if not args.save_optimizer_state:
dist.barrier()
if rank == 0 and len(dst_paths) > 0:
# Delete optimizer states to avoid occupying too much disk space.
for dst_path in dst_paths:
for opt_state_path in glob(f"{dst_path}/zero_*_optim_states.pt"):
os.remove(opt_state_path)
return True
@torch.no_grad()
def prepare_model_inputs(
args, batch, device, vae, text_encoder, img_encoder, text_encoder_t5, freqs_cis_img
):
(
image,
text_embedding,
text_embedding_mask,
text_embedding_t5,
text_embedding_mask_t5,
img_for_clip_tensor,
kwargs,
) = batch
# clip & mT5 text embedding
text_embedding = text_embedding.to(device)
text_embedding_mask = text_embedding_mask.to(device)
encoder_hidden_states = text_encoder(
text_embedding.to(device),
attention_mask=text_embedding_mask.to(device),
)[0]
text_embedding_t5 = text_embedding_t5.to(device).squeeze(1)
text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1)
with torch.no_grad():
output_t5 = text_encoder_t5(
input_ids=text_embedding_t5,
attention_mask=(
text_embedding_mask_t5 if T5_ENCODER["attention_mask"] else None
),
output_hidden_states=True,
)
encoder_hidden_states_t5 = output_t5["hidden_states"][
T5_ENCODER["layer_index"]
].detach()
# additional condition
if args.size_cond:
image_meta_size = kwargs["image_meta_size"].to(device)
else:
image_meta_size = None
if args.use_style_cond:
style = kwargs["style"].to(device)
else:
style = None
if args.extra_fp16:
image = image.half()
image_meta_size = (
image_meta_size.half() if image_meta_size is not None else None
)
style = style.half() if style is not None else None
# Map input images to latent space + normalize latents:
image = image.to(device)
vae_scaling_factor = vae.config.scaling_factor
latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor)
# positional embedding
_, _, height, width = image.shape
reso = f"{height}x{width}"
cos_cis_img, sin_cis_img = freqs_cis_img[reso]
img_clip_embedding = img_encoder(img_for_clip_tensor.to(latents.device))
# Model conditions
model_kwargs = dict(
encoder_hidden_states=encoder_hidden_states,
text_embedding_mask=text_embedding_mask,
img_clip_embedding=img_clip_embedding,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=cos_cis_img,
sin_cis_img=sin_cis_img,
t_scale=1,
i_scale=1,
)
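# t_scale / i_scale are forwarded to the model as-is; both are fixed to 1 here,
# which (as we read it) weights the text conditioning and the CLIP image
# conditioning equally during IP-Adapter training.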
return latents, model_kwargs
def main(args):
if args.training_parts == "lora":
args.use_ema = False
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
dist.init_process_group("nccl")
world_size = dist.get_world_size()
batch_size = args.batch_size
grad_accu_steps = args.grad_accu_steps
global_batch_size = world_size * batch_size * grad_accu_steps
rank = dist.get_rank()
device = rank % torch.cuda.device_count()
seed = args.global_seed * world_size + rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.set_device(device)
print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")
deepspeed_config = deepspeed_config_from_args(args, global_batch_size)
# Setup an experiment folder
experiment_dir, checkpoint_dir, logger = create_exp_folder(args, rank)
# Log all the arguments
logger.info(sys.argv)
logger.info(str(args))
# Save to a json file
args_dict = vars(args)
args_dict["world_size"] = world_size
with open(f"{experiment_dir}/args.json", "w") as f:
json.dump(args_dict, f, indent=4)
# Disable the message "Some weights of the model checkpoint at ... were not used when initializing BertModel."
# If you need those warnings, comment out the following line.
tf_logging.set_verbosity_error()
# ===========================================================================
# Building HYDIT
# ===========================================================================
logger.info("Building HYDIT Model.")
# ---------------------------------------------------------------------------
# Training sample base size, such as 256/512/1024. Note that this is only a
# base size, not necessarily the actual size of the training samples. The
# actual sample sizes are determined by `resolutions` when multi-resolution
# training is enabled.
# ---------------------------------------------------------------------------
image_size = args.image_size
if len(image_size) == 1:
image_size = [image_size[0], image_size[0]]
if len(image_size) != 2:
raise ValueError(f"Invalid image size: {args.image_size}")
assert image_size[0] % 8 == 0 and image_size[1] % 8 == 0, (
"Image size must be divisible by 8 (for the VAE encoder). " f"got {image_size}"
)
latent_size = [image_size[0] // 8, image_size[1] // 8]
# initialize model by deepspeed
assert (
args.deepspeed
), f"Must enable deepspeed in this script: train_deepspeed_ipadapter.py"
with deepspeed.zero.Init(
data_parallel_group=torch.distributed.group.WORLD,
remote_device=None if args.remote_device == "none" else args.remote_device,
config_dict_or_path=deepspeed_config,
mpu=None,
enabled=args.zero_stage == 3,
):
model = HUNYUAN_DIT_MODELS[args.model](
args,
input_size=latent_size,
log_fn=logger.info,
)
# Multi-resolution / Single-resolution training.
if args.multireso:
resolutions = ResolutionGroup(
image_size[0],
align=16,
step=args.reso_step,
target_ratios=args.target_ratios,
).data
else:
resolutions = ResolutionGroup(
image_size[0], align=16, target_ratios=["1:1"]
).data
freqs_cis_img = init_image_posemb(
args.rope_img,
resolutions=resolutions,
patch_size=model.patch_size,
hidden_size=model.hidden_size,
num_heads=model.num_heads,
log_fn=logger.info,
rope_real=args.rope_real,
)
# Create EMA model and convert to fp16 if needed.
ema = None
if args.use_ema:
ema = EMA(args, model, device, logger)
# Setup gradient checkpointing
if args.gradient_checkpointing:
model.enable_gradient_checkpointing()
# Setup FP16 main model:
if args.use_fp16:
model = Float16Module(model, args)
logger.info(
f" Using main model with data type {'fp16' if args.use_fp16 else 'fp32'}"
)
diffusion = create_diffusion(
noise_schedule=args.noise_schedule,
predict_type=args.predict_type,
learn_sigma=args.learn_sigma,
mse_loss_weight_type=args.mse_loss_weight_type,
beta_start=args.beta_start,
beta_end=args.beta_end,
noise_offset=args.noise_offset,
)
# Setup VAE
logger.info(f" Loading vae from {VAE_EMA_PATH}")
vae = AutoencoderKL.from_pretrained(VAE_EMA_PATH)
# Setup BERT text encoder
logger.info(f" Loading Bert text encoder from {TEXT_ENCODER}")
text_encoder = BertModel.from_pretrained(TEXT_ENCODER, False, revision=None)
# Setup BERT tokenizer:
logger.info(f" Loading Bert tokenizer from {TOKENIZER}")
tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
# Setup T5 text encoder
mt5_path = T5_ENCODER["MT5"]
embedder_t5 = MT5Embedder(
mt5_path, torch_dtype=T5_ENCODER["torch_dtype"], max_length=args.text_len_t5
)
tokenizer_t5 = embedder_t5.tokenizer
text_encoder_t5 = embedder_t5.model
if args.extra_fp16:
logger.info(f" Using fp16 for extra modules: vae, text_encoder")
vae = vae.half().to(device)
text_encoder = text_encoder.half().to(device)
text_encoder_t5 = text_encoder_t5.half().to(device)
else:
vae = vae.to(device)
text_encoder = text_encoder.to(device)
text_encoder_t5 = text_encoder_t5.to(device)
logger.info(
f" Optimizer parameters: lr={args.lr}, weight_decay={args.weight_decay}"
)
logger.info(" Using deepspeed optimizer")
opt = None
img_encoder = ImgClipEmbDetector()
# ===========================================================================
# Building Dataset
# ===========================================================================
logger.info(f"Building Streaming Dataset.")
logger.info(f" Loading index file {args.index_file} (v2)")
dataset = TextImageArrowStream(
args=args,
resolution=image_size[0],
random_flip=args.random_flip,
log_fn=logger.info,
index_file=args.index_file,
multireso=args.multireso,
batch_size=batch_size,
world_size=world_size,
random_shrink_size_cond=args.random_shrink_size_cond,
merge_src_cond=args.merge_src_cond,
uncond_p=args.uncond_p,
text_ctx_len=args.text_len,
tokenizer=tokenizer,
uncond_p_t5=args.uncond_p_t5,
text_ctx_len_t5=args.text_len_t5,
tokenizer_t5=tokenizer_t5,
)
if args.multireso:
sampler = BlockDistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
seed=args.global_seed,
shuffle=False,
drop_last=True,
batch_size=batch_size,
)
else:
sampler = DistributedSamplerWithStartIndex(
dataset,
num_replicas=world_size,
rank=rank,
seed=args.global_seed,
shuffle=False,
drop_last=True,
)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
sampler=sampler,
num_workers=args.num_workers,
pin_memory=True,
drop_last=True,
)
logger.info(f" Dataset contains {len(dataset):,} images.")
logger.info(f" Index file: {args.index_file}.")
if args.multireso:
logger.info(
f" Using MultiResolutionBucketIndexV2 with step {dataset.index_manager.step} "
f"and base size {dataset.index_manager.base_size}"
)
logger.info(f"\n {dataset.index_manager.resolutions}")
# ===========================================================================
# Loading parameter
# ===========================================================================
logger.info(f"Loading parameter")
start_epoch = 0
start_epoch_step = 0
train_steps = 0
# Resume checkpoint if needed
if args.resume:
model, ema, start_epoch, start_epoch_step, train_steps = model_resume(
args, model, ema, logger, len(loader)
)
if args.training_parts == "lora":
loraconfig = LoraConfig(
r=args.rank, lora_alpha=args.rank, target_modules=args.target_modules
)
if args.use_fp16:
model.module = get_peft_model(model.module, loraconfig)
else:
model = get_peft_model(model, loraconfig)
logger.info(f" Training parts: {args.training_parts}")
model, opt, scheduler = deepspeed_initialize(
args, logger, model, opt, deepspeed_config
)
# ===========================================================================
# Training
# ===========================================================================
model.train()
if args.use_ema:
ema.eval()
print(f" Worker {rank} ready.")
dist.barrier()
iters_per_epoch = len(loader)
logger.info(
" ****************************** Running training ******************************"
)
logger.info(f" Number GPUs: {world_size}")
logger.info(f" Number training samples: {len(dataset):,}")
logger.info(
f" Number parameters: {sum(p.numel() for p in model.parameters()):,}"
)
logger.info(
f" Number trainable params: {sum(p.numel() for p in get_trainable_params_ipa(model, args, freeze_others=True)):,}"
)
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Iters per epoch: {iters_per_epoch:,}")
logger.info(f" Batch size per device: {batch_size}")
logger.info(
f" Batch size all device: {batch_size * world_size * grad_accu_steps:,} (world_size * batch_size * grad_accu_steps)"
)
logger.info(f" Gradient Accu steps: {args.grad_accu_steps}")
logger.info(
f" Total optimization steps: {args.epochs * iters_per_epoch // grad_accu_steps:,}"
)
logger.info(f" Training epochs: {start_epoch}/{args.epochs}")
logger.info(
f" Training epoch steps: {start_epoch_step:,}/{iters_per_epoch:,}"
)
logger.info(
f" Training total steps: {train_steps:,}/{min(args.max_training_steps, args.epochs * iters_per_epoch):,}"
)
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Noise schedule: {args.noise_schedule}")
logger.info(
f" Beta limits: ({args.beta_start}, {args.beta_end})"
)
logger.info(f" Learn sigma: {args.learn_sigma}")
logger.info(f" Prediction type: {args.predict_type}")
logger.info(f" Noise offset: {args.noise_offset}")
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Using EMA model: {args.use_ema} ({args.ema_dtype})")
if args.use_ema:
logger.info(
f" Using EMA decay: {ema.max_value if args.use_ema else None}"
)
logger.info(
f" Using EMA warmup power: {ema.power if args.use_ema else None}"
)
logger.info(f" Using main model fp16: {args.use_fp16}")
logger.info(f" Using extra modules fp16: {args.extra_fp16}")
logger.info(
" ------------------------------------------------------------------------------"
)
logger.info(f" Experiment directory: {experiment_dir}")
logger.info(
" *******************************************************************************"
)
if args.gc_interval > 0:
gc.disable()
gc.collect()
# Variables for monitoring/logging purposes:
log_steps = 0
running_loss = 0
start_time = time.time()
# Training loop
epoch = start_epoch
while epoch < args.epochs:
# Random shuffle dataset
shuffle_seed = args.global_seed + epoch
logger.info(f" Start random shuffle with seed={shuffle_seed}")
# Make sure all processes use the same seed to shuffle the dataset.
dataset.shuffle(seed=shuffle_seed, fast=True)
logger.info(f" End of random shuffle")
# Move sampler to start_index
if not args.multireso:
start_index = start_epoch_step * world_size * batch_size
if start_index != sampler.start_index:
sampler.start_index = start_index
# Reset start_epoch_step to zero, to ensure next epoch will start from the beginning.
start_epoch_step = 0
logger.info(f" Iters left this epoch: {len(loader):,}")
logger.info(f" Beginning epoch {epoch}...")
for batch in loader:
latents, model_kwargs = prepare_model_inputs(
args,
batch,
device,
vae,
text_encoder,
img_encoder,
text_encoder_t5,
freqs_cis_img,
)
loss_dict = diffusion.training_losses(
model=model, x_start=latents, model_kwargs=model_kwargs
)
loss = loss_dict["loss"].mean()
model.backward(loss)
last_batch_iteration = (train_steps + 1) // (
global_batch_size // (batch_size * world_size)
)
model.step(lr_kwargs={"last_batch_iteration": last_batch_iteration})
if args.use_ema:
if args.use_fp16:
ema.update(model.module.module, step=train_steps)
else:
ema.update(model.module, step=train_steps)
# ===========================================================================
# Log loss values:
# ===========================================================================
running_loss += loss.item()
log_steps += 1
train_steps += 1
if train_steps % args.log_every == 0:
# Measure training speed:
torch.cuda.synchronize()
end_time = time.time()
steps_per_sec = log_steps / (end_time - start_time)
# Reduce loss history over all processes:
avg_loss = torch.tensor(running_loss / log_steps, device=device)
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
avg_loss = avg_loss.item() / world_size
# get lr from deepspeed fused optimizer
logger.info(
f"(step={train_steps:07d}) "
+ (
f"(update_step={train_steps // args.grad_accu_steps:07d}) "
if args.grad_accu_steps > 1
else ""
)
+ f"Train Loss: {avg_loss:.4f}, "
f"Lr: {opt.param_groups[0]['lr']:.6g}, "
f"Steps/Sec: {steps_per_sec:.2f}, "
f"Samples/Sec: {int(steps_per_sec * batch_size * world_size):d}"
)
# Reset monitoring variables:
running_loss = 0
log_steps = 0
start_time = time.time()
# collect gc:
if args.gc_interval > 0 and (train_steps % args.gc_interval == 0):
gc.collect()
if (
train_steps % args.ckpt_every == 0
or train_steps % args.ckpt_latest_every == 0
) and train_steps > 0:
save_checkpoint(
args,
rank,
logger,
model,
ema,
epoch,
train_steps,
checkpoint_dir,
by="step",
)
if train_steps >= args.max_training_steps:
logger.info(f"Breaking step loop at train_steps={train_steps}.")
break
if train_steps >= args.max_training_steps:
logger.info(f"Breaking epoch loop at epoch={epoch}.")
break
# Finish an epoch
if args.ckpt_every_n_epoch > 0 and epoch % args.ckpt_every_n_epoch == 0:
save_checkpoint(
args,
rank,
logger,
model,
ema,
epoch,
train_steps,
checkpoint_dir,
by="epoch",
)
epoch += 1
save_checkpoint(
args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir, by="final"
)
dist.destroy_process_group()
if __name__ == "__main__":
# Start
main(get_args())
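# NCCL / InfiniBand environment settings for multi-node training; the interface (bond1) and HCA names below are cluster-specific.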
export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECK_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_SOCKET_IFNAME=bond1
export UCX_NET_DEVICES=bond1
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=1
task_flag="dit_g2_full_1024p" # the task flag is used to identify folders.
resume=ckpts/t2i/model # checkpoint root for resume
results_dir=./log_EXP # save root for results
batch_size=1 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=1 # gradient accumulation
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=100 # create a ckpt every a few steps.
ckpt_latest_every=10000 # create a ckpt named `latest.pt` every a few steps.
ckpt_every_n_epoch=2 # create a ckpt every a few epochs.
epochs=8 # total training epochs
PYTHONPATH=. \
sh $(dirname "$0")/run_g_ipadapter.sh \
--task-flag ${task_flag} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.018 \
--predict-type v_prediction \
--multireso \
--reso-step 64 \
--uncond-p 0.22 \
--uncond-p-t5 0.22 \
--uncond-p-img 0.05 \
--index-file /path/to/your/index_file \
--random-flip \
--lr ${lr} \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-flash-attn \
--use-fp16 \
--extra-fp16 \
--results-dir ${results_dir} \
--resume \
--resume-module-root ckpts/t2i/model/pytorch_model_module.pt \
--epochs ${epochs} \
--ckpt-every ${ckpt_every} \
--ckpt-latest-every ${ckpt_latest_every} \
--ckpt-every-n-epoch ${ckpt_every_n_epoch} \
--log-every 10 \
--deepspeed \
--use-zero-stage 2 \
--gradient-checkpointing \
--no-strict \
--training-parts ipadapter \
--is-ipa True \
--resume-ipa True \
--resume-ipa-root ckpts/t2i/model/ipa.pt \
"$@"
task_flag="dit_g2_full_1024p" # the task flag is used to identify folders.
resume_module_root=./ckpts/t2i/model_v1_1/pytorch_model_distill.pt # checkpoint root for resume
index_file=dataset/porcelain/jsons/porcelain.json # index file for dataloader
results_dir=./log_EXP # save root for results
batch_size=1 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=1 # gradient accumulation
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=9999999 # create a ckpt every a few steps.
ckpt_latest_every=9999999 # create a ckpt named `latest.pt` every a few steps.
ckpt_every_n_epoch=2 # create a ckpt every a few epochs.
epochs=8 # total training epochs
sh $(dirname "$0")/run_g.sh \
--task-flag ${task_flag} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \
--predict-type v_prediction \
--uncond-p 0.44 \
--uncond-p-t5 0.44 \
--index-file ${index_file} \
--random-flip \
--lr ${lr} \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-flash-attn \
--use-fp16 \
--extra-fp16 \
--results-dir ${results_dir} \
--resume \
--resume-module-root ${resume_module_root} \
--epochs ${epochs} \
--ckpt-every ${ckpt_every} \
--ckpt-latest-every ${ckpt_latest_every} \
--ckpt-every-n-epoch ${ckpt_every_n_epoch} \
--log-every 10 \
--deepspeed \
--use-zero-stage 2 \
--gradient-checkpointing \
--use-style-cond \
--size-cond 1024 1024 \
"$@"
# Hunyuan CLIP model initialization
import sys
import tqdm
import argparse
from PIL import Image
import copy
import json
import os
import numpy as np
sys.path.append(os.getcwd())
from transformers import BertTokenizer
import torch
from torchvision.transforms import (
Compose,
ToTensor,
Normalize,
Resize,
InterpolationMode,
)
from collections import OrderedDict
from typing import Tuple, Union
from itertools import repeat
import collections.abc
import math
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
import importlib.util
# if importlib.util.find_spec('flash_attn'):
# FlashMHA = importlib.import_module('flash_attn.flash_attention').FlashMHA
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict(
[
("-1", nn.AvgPool2d(stride)),
(
"0",
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False,
),
),
("1", nn.BatchNorm2d(planes * self.expansion)),
]
)
)
def forward(self, x: torch.Tensor):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Module):
def __init__(
self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
2, 0, 1
) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(width // 2)
self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(width // 2)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(
input_resolution // 32, embed_dim, heads, output_dim
)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def forward(self, x):
def stem(x):
for conv, bn in [
(self.conv1, self.bn1),
(self.conv2, self.bn2),
(self.conv3, self.bn3),
]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None,
use_flash_attention: bool = False,
):
super().__init__()
# self.attn = nn.MultiheadAttention(d_model, n_head) if not use_flash_attention else FlashMHA(d_model, n_head)
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)),
]
)
)
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.use_flash_attention = use_flash_attention
def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None
else None
)
if self.use_flash_attention:
# Batch first is needed for FlashAttention. See https://github.com/HazyResearch/flash-attention/issues/84 for more information.
return self.attn(x.transpose(1, 0))[0].transpose(1, 0)
else:
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None,
use_flash_attention: bool = False,
):
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = False
self.resblocks = nn.Sequential(
*[
ResidualAttentionBlock(width, heads, attn_mask, use_flash_attention)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor):
if self.grad_checkpointing and not torch.jit.is_scripting():
for r in self.resblocks:
x = checkpoint(r, x)
return x
return self.resblocks(x)
class VisualTransformer(nn.Module):
def __init__(
self,
input_resolution: int,
patch_size: int,
width: int,
layers: int,
heads: int,
output_dim: int,
use_flash_attention: bool = False,
):
super().__init__()
self.input_resolution = input_resolution
self.grid_size = (
self.input_resolution // patch_size,
self.input_resolution // patch_size,
)
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(
scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
)
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(
width, layers, heads, use_flash_attention=use_flash_attention
)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def random_masking(self, x, mask_ratio):
N, L, D = x.shape # batch, length, dim
len_keep = int((L - 1) * (1 - mask_ratio))
noise = torch.rand(N, L - 1, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(
N, L - 1, device=x.device, dtype=int
)
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
x0 = x[:, 0, :]
x0 = x0.reshape(N, 1, D)
x_masked_add = torch.cat([x0, x_masked], axis=1)
return x_masked_add
def forward(self, x: torch.Tensor, mask_ratio: float = 0.0):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[
self.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
if mask_ratio != 0:
x = self.random_masking(x, mask_ratio)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
vocab_size: int,
text_attention_probs_dropout_prob: float,
text_hidden_act: str,
text_hidden_dropout_prob: float,
text_hidden_size: int,
text_initializer_range: float,
text_intermediate_size: int,
text_max_position_embeddings: int,
text_num_attention_heads: int,
text_num_hidden_layers: int,
text_type_vocab_size: int,
# tokenizer = _tokenizer,
# vision head width, added this param for ViT-H
vision_head_width: int = 64,
use_flash_attention: bool = False,
args={},
):
super().__init__()
print("use_flash_attention", use_flash_attention)
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // vision_head_width
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width,
)
else:
vision_heads = vision_width // vision_head_width
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim,
use_flash_attention=use_flash_attention,
)
# load isogclr loss tempNetwork
if (
"contrastive_loss_type" in args
and args.contrastive_loss_type == "isogclr_loss"
):
feat_dim = args.isogclr_loss_feat_dim
squeeze_dim = args.isogclr_loss_squeeze_dim
tau_min, tau_max = args.isogclr_loss_tau_min, args.isogclr_loss_tau_max
if args.isogclr_loss_temp_input == "unimodal":
self.image_temp_gen = TempGenerator(
feature_dim=feat_dim,
M=squeeze_dim,
tau_min=tau_min,
tau_max=tau_max,
)
self.text_temp_gen = TempGenerator(
feature_dim=feat_dim,
M=squeeze_dim,
tau_min=tau_min,
tau_max=tau_max,
)
else:
self.temp_gen = TempGenerator(
feature_dim=feat_dim,
M=squeeze_dim,
tau_min=tau_min,
tau_max=tau_max,
)
else:
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.args = args
self.initialize_parameters()
def initialize_parameters(self):
# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [
self.visual.layer1,
self.visual.layer2,
self.visual.layer3,
self.visual.layer4,
]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.bert.set_grad_checkpointing(enable)
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image, mask_ratio=0):
if isinstance(self.visual, ModifiedResNet):
# mask_ratio > 0 (FLIP strategy) is currently only implemented for VisualTransformer.
return self.visual(image.type(self.dtype))
return self.visual(image.type(self.dtype), mask_ratio)
def convert_models_to_fp32(model):
for p in model.parameters():
p.data = p.data.float()
if p.grad is not None:
p.grad.data = p.grad.data.float()
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [
*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
"in_proj_bias",
"bias_k",
"bias_v",
]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
if isinstance(l, BertModel):
l.to(torch.half)
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def convert_state_dict(state_dict):
"""Adapt to Flash Attention"""
if not state_dict:
return state_dict
prefix = "module." if list(state_dict.keys())[0].startswith("module") else ""
if f"{prefix}visual.transformer.resblocks.0.attn.in_proj_weight" in state_dict:
for k in list(state_dict.keys()):
if "attn.in_proj_weight" in k:
state_dict[k.replace("attn.in_proj_weight", "attn.Wqkv.weight")] = (
state_dict.pop(k)
)
elif "attn.in_proj_bias" in k:
state_dict[k.replace("attn.in_proj_bias", "attn.Wqkv.bias")] = (
state_dict.pop(k)
)
elif f"{prefix}visual.transformer.resblocks.0.attn.Wqkv.weight" in state_dict:
for k in list(state_dict.keys()):
if "attn.Wqkv.weight" in k:
state_dict[k.replace("attn.Wqkv.weight", "attn.in_proj_weight")] = (
state_dict.pop(k)
)
elif "attn.Wqkv.bias" in k:
state_dict[k.replace("attn.Wqkv.bias", "attn.in_proj_bias")] = (
state_dict.pop(k)
)
if f"{prefix}bert.encoder.layer.0.attention.self.query.weight" in state_dict:
i = 0
while (
f"{prefix}bert.encoder.layer.{i}.attention.self.query.weight" in state_dict
):
state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight"] = (
torch.cat(
(
state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.query.weight"
),
state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.key.weight"
),
state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.value.weight"
),
)
)
)
state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias"] = (
torch.cat(
(
state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.query.bias"
),
state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.key.bias"
),
state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.value.bias"
),
)
)
)
state_dict[
f"{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight"
] = state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.output.dense.weight"
)
state_dict[
f"{prefix}bert.encoder.layer.{i}.attention.self.out_proj.bias"
] = state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.output.dense.bias"
)
i += 1
elif f"{prefix}bert.encoder.layer.0.attention.self.Wqkv.weight" in state_dict:
i = 0
while (
f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight" in state_dict
):
(
state_dict[
f"{prefix}bert.encoder.layer.{i}.attention.self.query.weight"
],
state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.key.weight"],
state_dict[
f"{prefix}bert.encoder.layer.{i}.attention.self.value.weight"
],
) = torch.chunk(
state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.weight"
),
chunks=3,
)
(
state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.query.bias"],
state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.key.bias"],
state_dict[f"{prefix}bert.encoder.layer.{i}.attention.self.value.bias"],
) = torch.chunk(
state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.Wqkv.bias"
),
chunks=3,
)
state_dict[
f"{prefix}bert.encoder.layer.{i}.attention.output.dense.weight"
] = state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.out_proj.weight"
)
state_dict[
f"{prefix}bert.encoder.layer.{i}.attention.output.dense.bias"
] = state_dict.pop(
f"{prefix}bert.encoder.layer.{i}.attention.self.out_proj.bias"
)
i += 1
return state_dict
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
def _convert_to_rgb(image):
return image.convert("RGB")
def image_transform(image_size=224):
transform = Compose(
[
Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
_convert_to_rgb,
ToTensor(),
Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)
return transform
image_preprocess = image_transform(224)
def convert_models_to_fp32(model):
for p in model.parameters():
p.data = p.data.float()
if p.grad is not None:
p.grad.data = p.grad.data.float()
class ImgClipEmbDetector:
def __init__(self):
# Used for the predefined multi-class classification
self.image_preprocess = image_preprocess
config = dict()
config["vision_config"] = "ipadapter/model_configs/ViT-H-14.json"
config["gpu"] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config["text_config"] = (
"ipadapter/model_configs/RoBERTa-wwm-ext-large-cn-en.json"
)
config["precision"] = "amp"
config["context_length"] = 77
config["resume"] = "ckpts/t2i/clip_img_encoder/clip_img_encoder.pt"
self.cfg = config
self.model = self.build_model()
def build_model(self):
with open(self.cfg["vision_config"], "r") as fv, open(
self.cfg["text_config"], "r"
) as ft:
model_info = json.load(fv)
if isinstance(model_info["vision_layers"], str):
model_info["vision_layers"] = eval(model_info["vision_layers"])
for k, v in json.load(ft).items():
model_info[k] = v
model = CLIP(**model_info)
if self.cfg["precision"] == "fp16":
convert_weights(model)
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
if self.cfg["precision"] == "amp" or self.cfg["precision"] == "fp32":
convert_models_to_fp32(model)
if self.cfg["precision"] == "fp16":
convert_weights(model)
checkpoint = torch.load(self.cfg["resume"], map_location="cpu")
sd = checkpoint["state_dict"]
if next(iter(sd.items()))[0].startswith("module"):
sd = {
k[len("module.") :]: v for k, v in sd.items() if "bert.pooler" not in k
}
model.load_state_dict(sd, strict=False)
print(
f"=> loaded checkpoint {self.cfg['resume']} (epoch {checkpoint['epoch']} @ {checkpoint['step']} steps)"
)
model.eval()
loc = self.cfg["gpu"]
model.to(loc)
self.cfg["device"] = loc
return model
def __call__(self, images):
# images: a batch tensor of preprocessed images
with torch.no_grad():
image_features = self.model.encode_image(images)
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features
import random
import logging
from pathlib import Path
import shutil
import numpy as np
from PIL import Image
import torch
import torch.distributed as dist
from tqdm.auto import tqdm
import math
import torch.nn.functional as F
import os
def get_trainable_params(model):
params = model.parameters()
params = [p for p in params if p.requires_grad]
return params
def get_trainable_params_ipa(model, args, freeze_others=False):
if args.training_parts == "all":
params = model.parameters()
elif args.training_parts == "time_embedding":
params = [p for n, p in model.named_parameters() if "t_embedder" in n]
if freeze_others:
for n, p in model.named_parameters():
if "t_embedder" not in n:
p.requires_grad_(False)
elif (
args.training_parts == "adapt_concat_to_text_concat"
): # adapt concat to text_concat
def valid_name(n):
if (
"default_modulation" in n
or "image_meta_size_embedder" in n
or "t_embedder" in n
):
return True
return False
params = []
for n, p in model.named_parameters():
if valid_name(n):
params.append(p)
elif freeze_others:
p.requires_grad_(False)
elif args.training_parts == "ipadapter":
params = [p for n, p in model.named_parameters() if "ip_adapter" in n]
# print('params', params)
if freeze_others:
for n, p in model.named_parameters():
if "ip_adapter" not in n:
p.requires_grad_(False)
else:
pass
else:
raise ValueError(f"Unknown training_parts {args.training_parts}")
return params
def set_seeds(seed_list, device=None):
if isinstance(seed_list, (tuple, list)):
seed = sum(seed_list)
else:
seed = seed_list
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
return torch.Generator(device).manual_seed(seed)
def get_start_epoch(resume_path, ckpt, steps_per_epoch):
if "epoch" in ckpt:
start_epoch = ckpt["epoch"]
else:
start_epoch = 0
if "steps" in ckpt:
train_steps = ckpt["steps"]
else:
try:
train_steps = int(Path(resume_path).stem)
except:
train_steps = start_epoch * steps_per_epoch
start_epoch_step = train_steps % steps_per_epoch + 1
return start_epoch, start_epoch_step, train_steps
def assert_shape(*args):
if len(args) < 2:
return
cond = True
fail_str = f"{args[0] if isinstance(args[0], (list, tuple)) else args[0].shape}"
for i in range(1, len(args)):
shape1 = args[i] if isinstance(args[i], (list, tuple)) else args[i].shape
shape2 = (
args[i - 1] if isinstance(args[i - 1], (list, tuple)) else args[i - 1].shape
)
cond = cond and (shape1 == shape2)
fail_str += (
f" vs {args[i] if isinstance(args[i], (list, tuple)) else args[i].shape}"
)
assert cond, fail_str
def create_logger(logging_dir=None, logging_file=None, ddp=True):
"""
Create a logger that writes to a log file and stdout.
"""
if not ddp or (ddp and dist.get_rank() == 0): # real logger
if logging_file is not None:
file_handler = [logging.FileHandler(logging_file)]
elif logging_dir is not None:
file_handler = [logging.FileHandler(f"{logging_dir}/log.txt")]
else:
file_handler = []
logging.basicConfig(
level=logging.INFO,
format="[\033[34m%(asctime)s\033[0m] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler()] + file_handler,
)
logger = logging.getLogger(__name__)
else:
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
return logger
def create_exp_folder(args, rank):
if rank == 0:
os.makedirs(args.results_dir, exist_ok=True)
existed_experiments = list(Path(args.results_dir).glob("*dit*"))
if len(existed_experiments) == 0:
experiment_index = 1
else:
existed_experiments.sort()
print("existed_experiments", existed_experiments)
experiment_index = (
max([int(x.stem.split("-")[0]) for x in existed_experiments]) + 1
)
dist.barrier()
model_string_name = (
args.task_flag if args.task_flag else args.model.replace("/", "-")
)
experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder
checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
logger = create_logger(experiment_dir)
logger.info(f"Experiment directory created at {experiment_dir}")
else:
logger = create_logger()
experiment_dir = ""
return experiment_dir, checkpoint_dir, logger
def model_resume(args, model, ema, logger, len_loader):
"""
Load pretrained weights.
"""
start_epoch = 0
start_epoch_step = 0
train_steps = 0
# Resume model
if args.resume:
resume_path = args.resume_module_root
if not Path(resume_path).exists():
raise FileNotFoundError(
f" Cannot find model checkpoint from {resume_path}"
)
logger.info(f" Resume from checkpoint {resume_path}")
resume_ckpt = torch.load(resume_path, map_location=lambda storage, loc: storage)
if "module" in resume_ckpt.keys():
model.load_state_dict(resume_ckpt["module"], strict=args.strict)
else:
model.load_state_dict(resume_ckpt, strict=args.strict)
# Resume EMA model
if args.use_ema:
resume_ema_path = args.resume_ema_root
if not Path(resume_ema_path).exists():
raise FileNotFoundError(
f" Cannot find ema checkpoint from {resume_ema_path}"
)
logger.info(f" Resume from ema checkpoint {resume_ema_path}")
resume_ema_ckpt = torch.load(
resume_ema_path, map_location=lambda storage, loc: storage
)
if "ema" in resume_ema_ckpt.keys():
ema.load_state_dict(resume_ema_ckpt["ema"], strict=args.strict)
elif "module" in resume_ema_ckpt.keys():
ema.load_state_dict(resume_ema_ckpt["module"], strict=args.strict)
else:
ema.load_state_dict(resume_ema_ckpt, strict=args.strict)
if not args.reset_loader:
start_epoch, start_epoch_step, train_steps = get_start_epoch(
args.resume, resume_ckpt, len_loader
)
if args.resume_ipa:
if Path(args.resume_ipa_root).exists():
logger.info(f" Resume from ipa checkpoint {args.resume_ipa_root}")
ipa_state_dict = torch.load(
args.resume_ipa_root, map_location=lambda storage, loc: storage
)
model.load_state_dict(ipa_state_dict, strict=False)
else:
raise FileNotFoundError(
f" Cannot find ipa-checkpoint from {args.resume_ipa_root}"
)
return model, ema, start_epoch, start_epoch_step, train_steps
## Using HunyuanDiT IP-Adapter
### Instructions
The dependencies and installation are basically the same as for the base model, and we use the `module` weights for training.
Download the model using the following commands:
```bash
cd HunyuanDiT
# Use the huggingface-cli tool to download the model.
# We recommend using module weights as the base model for IP-Adapter inference, as our provided pretrained weights are trained on them.
huggingface-cli download Tencent-Hunyuan/IP-Adapter ipa.pt --local-dir ./ckpts/t2i/model
huggingface-cli download Tencent-Hunyuan/IP-Adapter clip_img_encoder.pt --local-dir ./ckpts/t2i/model/clip_img_encoder
# Quick start
python3 sample_ipadapter.py --infer-mode fa --ref-image-path ipadapter/asset/input/tiger.png --i-scale 1.0 --prompt 一只老虎在海洋中游泳,背景是海洋。构图方式是居中构图,呈现了动漫风格和文化,营造了平静的氛围。 --infer-steps 100 --is-ipa True --load-key distill
```
Examples of ref input and IP-Adapter results are as follows:
<table>
<tr>
<td colspan="3" align="center">Ref Input</td>
</tr>
<tr>
<td align="center"><img src="asset/input/tiger.png" alt="Image 0" width="200"/></td>
<td align="center"><img src="asset/input/beauty.png" alt="Image 1" width="200"/></td>
<td align="center"><img src="asset/input/xunyicao.png" alt="Image 2" width="200"/></td>
</tr>
<tr>
<td colspan="3" align="center">IP-Adapter Output</td>
</tr>
<tr>
<td align="center">一只老虎在奔跑。<br>(A tiger running.) </td>
<td align="center">一个卡通美女,抱着一只小猪。<br>(A cartoon beauty holding a little pig.) </td>
<td align="center">一片紫色薰衣草地。<br>(A purple lavender field.) </td>
</tr>
<tr>
<td align="center"><img src="asset/output/tiger_run.png" alt="Image 3" width="200"/></td>
<td align="center"><img src="asset/output/beauty_pig.png" alt="Image 4" width="200"/></td>
<td align="center"><img src="asset/output/xunyicao_res.png" alt="Image 5" width="200"/></td>
</tr>
<tr>
<td align="center">一只老虎在看书。<br>(A tiger is reading a book.) </td>
<td align="center">一个卡通美女,穿着绿色衣服。<br>(A cartoon beauty wearing green clothes.) </td>
<td align="center">一片紫色薰衣草地,有一只可爱的小狗。<br>(A purple lavender field with a cute puppy.) </td>
</tr>
<tr>
<td align="center"><img src="asset/output/tiger_book.png" alt="Image 3" width="200"/></td>
<td align="center"><img src="asset/output/beauty_green_cloth.png" alt="Image 4" width="200"/></td>
<td align="center"><img src="asset/output/xunyicao_dog.png" alt="Image 5" width="200"/></td>
</tr>
<tr>
<td align="center">一只老虎在咆哮。<br>(A tiger is roaring.) </td>
<td align="center">一个卡通美女,戴着墨镜。<br>(A cartoon beauty wearing sunglasses.) </td>
<td align="center">水墨风格,一片紫色薰衣草地。<br>(Ink style. A purple lavender field.) </td>
</tr>
<tr>
<td align="center"><img src="asset/output/tiger_roar.png" alt="Image 3" width="200"/></td>
<td align="center"><img src="asset/output/beauty_glass.png" alt="Image 4" width="200"/></td>
<td align="center"><img src="asset/output/xunyicao_style.png" alt="Image 5" width="200"/></td>
</tr>
</table>
### Training
We provide base model weights for IP-Adapter training; use the `module` weights as the base model.
In the following example, we load the `module` weights into the main model and train the IP-Adapter.
If you apply multi-resolution training, add the `--multireso` and `--reso-step 64` parameters.
```bash
task_flag="IP_Adapter" # the task flag is used to identify folders.
index_file=path/to/your/index_file
results_dir=./log_EXP # save root for results
batch_size=1 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=1 # gradient accumulation
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=10 # create a ckpt every a few steps.
ckpt_latest_every=10000 # create a ckpt named `latest.pt` every a few steps.
ckpt_every_n_epoch=2 # create a ckpt every a few epochs.
epochs=8 # total training epochs
PYTHONPATH=. \
sh $(dirname "$0")/run_g_ipadapter.sh \
--task-flag ${task_flag} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.018 \
--predict-type v_prediction \
--multireso \
--reso-step 64 \
--uncond-p 0.22 \
--uncond-p-t5 0.22 \
--uncond-p-img 0.05 \
--index-file ${index_file} \
--random-flip \
--lr ${lr} \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-flash-attn \
--use-fp16 \
--extra-fp16 \
--results-dir ${results_dir} \
--resume \
--resume-module-root ckpts/t2i/model/pytorch_model_module.pt \
--epochs ${epochs} \
--ckpt-every ${ckpt_every} \
--ckpt-latest-every ${ckpt_latest_every} \
--ckpt-every-n-epoch ${ckpt_every_n_epoch} \
--log-every 10 \
--deepspeed \
--use-zero-stage 2 \
--gradient-checkpointing \
--no-strict \
--training-parts ipadapter \
--is-ipa True \
--resume-ipa True \
--resume-ipa-root ckpts/t2i/model/ipa.pt \
"$@"
```
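For reference, `--training-parts ipadapter` keeps only the parameters whose names contain `ip_adapter` trainable and freezes everything else (see the `get_trainable_params_ipa` helper in this repo). Below is a minimal, self-contained sketch of that selection rule using a toy model; the module names are made up purely for illustration.

```python
import torch
from torch import nn

# Toy stand-in for the DiT backbone plus an IP-Adapter branch (names are illustrative only).
model = nn.ModuleDict({
    "backbone": nn.Linear(8, 8),
    "ip_adapter_proj": nn.Linear(8, 8),
})

# Same selection rule as get_trainable_params_ipa with training_parts="ipadapter" and freeze_others=True:
# collect parameters whose names contain "ip_adapter" and freeze the rest.
trainable = []
for name, param in model.named_parameters():
    if "ip_adapter" in name:
        trainable.append(param)
    else:
        param.requires_grad_(False)

print(f"trainable params: {sum(p.numel() for p in trainable):,}")
```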
Recommended parameter settings
| Parameter | Description | Recommended Parameter Value | Note|
|:---------------:|:---------:|:---------------------------------------------------:|:--:|
| `--batch-size` | Training batch size | 1 | Depends on GPU memory|
| `--grad-accu-steps` | Size of gradient accumulation | 2 | - |
| `--lr` | Learning rate | 0.0001 | - |
| `--training-parts` | Parameters to train for IP-Adapter | ipadapter | - |
| `--is-ipa` | Whether IP-Adapter training is enabled | True | - |
| `--resume-ipa-root` | Path of the IP-Adapter checkpoint to resume from | IP-Adapter checkpoint path | - |
### Inference
Use the following command line for inference.
Use the float parameter `--i-scale` to specify the weight of the IP-Adapter reference image; a larger value makes the output follow the reference image more closely.
```bash
python3 sample_ipadapter.py --infer-mode fa --ref-image-path ipadapter/input/beach.png --i-scale 1.0 --prompt 一只老虎在海洋中游泳,背景是海洋。构图方式是居中构图,呈现了动漫风格和文化,营造了平静的氛围。 --infer-steps 100 --is-ipa True --load-key module
```
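To compare how strongly the output follows the reference image, a small sweep over `--i-scale` can help. This is a sketch only: the loop and the chosen scale values are not part of the original scripts, and it assumes the same checkpoint layout and reference image as the command above.

```python
import subprocess

# Run sample_ipadapter.py with several reference-image weights.
# Larger --i-scale values make the result follow the reference image more closely.
for scale in (0.5, 1.0, 1.5):
    subprocess.run(
        [
            "python3", "sample_ipadapter.py",
            "--infer-mode", "fa",
            "--ref-image-path", "ipadapter/input/beach.png",
            "--i-scale", str(scale),
            "--prompt", "一只老虎在海洋中游泳,背景是海洋。",
            "--infer-steps", "100",
            "--is-ipa", "True",
            "--load-key", "module",
        ],
        check=True,
    )
```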