import os
import random

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

from args import get_args
from data.vitonhd import VITHONHD
from model.attn_processor import SkipAttnProcessor
from model.pipeline import CatVTONPipeline
from model.utils import init_adapter
from utils import compute_vae_encodings, prepare_image, prepare_mask_image


def init_models(model_root: str, weight_dtype: str = "no", vae_subfolder: str = "vae", device="cpu"):
    if weight_dtype == "no":
        torch_dtype = torch.float32
    elif weight_dtype == "fp16":
        torch_dtype = torch.float16
    elif weight_dtype == "bf16":
        torch_dtype = torch.bfloat16
    else:
        raise NotImplementedError(f"unsupported weight_dtype: {weight_dtype}")

    print(f"load vae from {vae_subfolder}")
    vae = AutoencoderKL.from_pretrained(model_root, subfolder=vae_subfolder)
    unet = UNet2DConditionModel.from_pretrained(model_root, subfolder="unet")
    try:
        noise_scheduler = DDIMScheduler.from_pretrained(model_root, subfolder="scheduler")
    except Exception:
        # Some checkpoints keep the scheduler config under "noise_scheduler" instead.
        noise_scheduler = DDIMScheduler.from_pretrained(model_root, subfolder="noise_scheduler")

    # Swap cross-attention for SkipAttnProcessor so the UNet runs without text conditioning.
    init_adapter(unet, cross_attn_cls=SkipAttnProcessor)
    vae.to(device, dtype=torch_dtype)
    unet.to(device, dtype=torch_dtype)

    # Freeze everything, then unfreeze only the self-attention ("attn1") modules.
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    for name, module in unet.named_modules():
        if "attn1" in name:
            module.requires_grad_(True)
    unet.train()
    # unet.enable_gradient_checkpointing()

    # Resume the optimizer state if a previous run saved one next to the weights.
    optimizer_path = os.path.join(model_root, "optim.pth")
    if os.path.exists(optimizer_path):
        optimizer_state_dict = torch.load(optimizer_path, map_location="cpu")
    else:
        optimizer_state_dict = None
    return noise_scheduler, vae, unet, optimizer_state_dict


def train_one_step(batch, noise_scheduler, vae, unet, device, extra_condition_key):
    person = prepare_image(batch["person"])
    cloth = prepare_image(batch["cloth"])
    mask = prepare_mask_image(batch["mask"])
    # Zero out the region to be inpainted.
    masked_person = person * (mask < 0.5)

    person_latent = compute_vae_encodings(person, vae)
    masked_person_latent = compute_vae_encodings(masked_person, vae)
    if random.random() < 0.15:
        # Randomly drop the cloth condition to enable classifier-free guidance.
        cloth_latent = torch.zeros_like(masked_person_latent)
    else:
        cloth_latent = compute_vae_encodings(cloth, vae)
    mask_latent = F.interpolate(mask, size=masked_person_latent.shape[-2:], mode="nearest")

    bsz = masked_person_latent.shape[0]
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device
    )

    # Stack person and cloth latents along the height axis, then add noise.
    first_input_latent = torch.cat([person_latent, cloth_latent], dim=-2)
    noise = torch.randn_like(first_input_latent)
    noisy_first_latent = noise_scheduler.add_noise(first_input_latent, noise, timesteps)

    masked_latent_concat = torch.cat([masked_person_latent, cloth_latent], dim=-2)
    extra_condition = batch.get(extra_condition_key, None)
    if extra_condition is None:
        # No extra condition: pad with zeros so the height-concat layout stays consistent.
        extra_condition = torch.zeros_like(mask_latent)
    else:
        extra_condition = F.interpolate(extra_condition, size=mask_latent.shape[-2:], mode="nearest")
    mask_latent_concat = torch.cat([mask_latent, extra_condition], dim=-2)

    # Channel-wise concat of noisy latents, mask, and masked latents (SD-inpainting layout).
    inpainting_latent_model_input = torch.cat(
        [noisy_first_latent, mask_latent_concat, masked_latent_concat], dim=1
    )
    noise_pred = unet(
        inpainting_latent_model_input,
        timesteps,
        encoder_hidden_states=None,
        return_dict=False,
    )[0]
    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
    return loss
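
# Shape sketch for the concatenation in train_one_step (illustrative only:
# assumes 512x384 inputs, the usual 8x VAE downsampling, and a 1-channel
# extra condition map; not checked against this repo's configs):
#   person_latent:        (B, 4, 64, 48)
#   first_input_latent:   (B, 4, 128, 48)  # person and cloth stacked on height (dim=-2)
#   mask_latent_concat:   (B, 1, 128, 48)  # mask stacked with the extra condition
#   masked_latent_concat: (B, 4, 128, 48)
#   model input:          (B, 9, 128, 48)  # channel concat (dim=1), matching the
#                                          # 9-channel conv_in of an SD-inpainting UNet
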
def main():
    args = get_args()
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.weight_dtype,
    )
    device = accelerator.device

    train_dataset = VITHONHD(
        args.train_data_record_path, args.height, args.width,
        extra_condition_key=args.extra_condition_key,
    )
    train_dataloader = DataLoader(
        dataset=train_dataset, batch_size=args.batch_size,
        num_workers=args.num_workers, shuffle=True,
    )
    noise_scheduler, vae, unet, optimizer_state_dict = init_models(
        args.model_root, device=device, vae_subfolder=args.vae_subfolder
    )
    optimizer = AdamW(
        unet.parameters(), lr=args.lr, betas=(args.beta1, args.beta2),
        eps=args.eps, weight_decay=args.weight_decay,
    )
    if optimizer_state_dict:
        print("loading optimizer state")
        optimizer.load_state_dict(optimizer_state_dict)

    # Only the main process runs evaluation.
    if accelerator.is_main_process:
        eval_dataset = VITHONHD(
            args.eval_data_record_path, args.height, args.width,
            is_train=False, extra_condition_key=args.extra_condition_key,
        )
        eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1, num_workers=args.num_workers)
    else:
        eval_dataloader = None

    unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

    global_step = args.global_steps
    reach_max_steps = False
    progress_bar = tqdm(initial=global_step, total=args.max_steps, disable=not accelerator.is_main_process)
    progress_bar.set_description("train catvton")

    while True:
        if reach_max_steps:
            print("reached max training steps, stopping")
            break
        avg_loss = 0.0
        for batch in train_dataloader:
            with accelerator.accumulate(unet):
                with accelerator.autocast():
                    loss = train_one_step(
                        batch, noise_scheduler, vae, unet, device, args.extra_condition_key
                    )
                avg_loss += loss.item()
                accelerator.backward(loss)  # TODO: needs a closer look
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                # Average the accumulated loss over processes and accumulation steps.
                avg_loss = torch.tensor(avg_loss).to(device)
                avg_loss = accelerator.gather(avg_loss).mean().item() / accelerator.gradient_accumulation_steps
                progress_bar.update(1)
                logs = {"step_loss": avg_loss, "global_steps": global_step}
                progress_bar.set_postfix(**logs)
                global_step += 1
                avg_loss = 0.0

                # Evaluate and save the model.
                if global_step % args.logging_steps == 0 or global_step >= args.max_steps:
                    if accelerator.is_main_process:
                        unwrap_unet = accelerator.unwrap_model(unet)
                        unwrap_unet.eval()
                        pipeline = CatVTONPipeline(noise_scheduler, vae, unwrap_unet)
                        os.makedirs(f"../eval_outputs/{args.eval_output_dir}/{global_step}", exist_ok=True)
                        with torch.no_grad():
                            for idx, eval_batch in enumerate(eval_dataloader):
                                if args.extra_condition_key:
                                    sample = pipeline(
                                        image=eval_batch["person"],
                                        condition_image=eval_batch["cloth"],
                                        mask=eval_batch["mask"],
                                        extra_condition=eval_batch[args.extra_condition_key],
                                    )[0]
                                else:
                                    sample = pipeline(
                                        image=eval_batch["person"],
                                        condition_image=eval_batch["cloth"],
                                        mask=eval_batch["mask"],
                                    )[0]
                                sample.save(f"../eval_outputs/{args.eval_output_dir}/{global_step}/{idx}.png")
                        save_path = os.path.join(args.output_dir, args.checkpoint_dir)
                        pipeline.save_pretrained(save_path)
                        torch.save(optimizer.state_dict(), f"{save_path}/optim.pth")
                        del pipeline
                        del unwrap_unet
                        torch.cuda.empty_cache()
                        unet.train()  # eval() flipped the shared modules; restore train mode
                    if global_step >= args.max_steps:
                        reach_max_steps = True
                        break


if __name__ == "__main__":
    main()
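
# A minimal launch sketch (assuming this file is train.py). The flag names are
# an assumption: they mirror the attribute names read off get_args() above and
# have not been checked against args.py; paths and hyperparameters are placeholders.
#
#   accelerate launch train.py \
#       --model_root /path/to/sd-inpainting-checkpoint \
#       --train_data_record_path /path/to/vitonhd_train_records \
#       --eval_data_record_path /path/to/vitonhd_eval_records \
#       --height 512 --width 384 --batch_size 8 --lr 1e-5 \
#       --weight_dtype bf16 --max_steps 100000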