import os

import numpy as np
import torch
from accelerate import Accelerator
from argparse import ArgumentParser
from PIL import Image, ImageFilter
from torch.utils.data import DataLoader
from tqdm import tqdm

from data.vitonhd import VITHONHD
from model.pipeline import CatVTONPipeline
from train import init_models

weight_dtype_maps = {
    "no": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def repaint(person, mask, result):
    """Paste the generated try-on result back onto the original person image inside the mask."""
    # Blur kernel scales with the image height and must be odd.
    _, h = result.size
    kernel_size = h // 50
    if kernel_size % 2 == 0:
        kernel_size += 1
    # mask = mask.filter(ImageFilter.GaussianBlur(kernel_size))
    person_np = np.array(person)
    result_np = np.array(result)
    mask_np = np.array(mask) / 255
    if mask_np.ndim == 2:
        # Single-channel mask: add a channel axis so it broadcasts over the RGB arrays.
        mask_np = mask_np[:, :, None]
    # Keep the original pixels outside the mask, the generated pixels inside it.
    repaint_result = person_np * (1 - mask_np) + result_np * mask_np
    return Image.fromarray(repaint_result.astype(np.uint8))


def get_args():
    parser = ArgumentParser()
    parser.add_argument("--model_root", type=str, required=True)
    parser.add_argument("--data_record_path", type=str, required=True)
    parser.add_argument("--vae_subfolder", type=str, default="vae")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=384)
    parser.add_argument("--extra_condition_key", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--weight_dtype", type=str, default="bf16", choices=list(weight_dtype_maps))
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=0.0)
    parser.add_argument("--repaint", action="store_true")
    parser.add_argument("--data_nums", type=int, default=None)
    return parser.parse_args()


def main():
    args = get_args()
    accelerator = Accelerator()
    device = accelerator.device

    noise_scheduler, vae, unet, optimizer_state_dict = init_models(
        args.model_root, device=device, vae_subfolder=args.vae_subfolder
    )
    del optimizer_state_dict  # only needed when resuming training

    unet.eval()
    vae.eval()  # TODO: check whether keeping the VAE in train mode gives better results
    vae.to(weight_dtype_maps[args.weight_dtype])
    unet.to(weight_dtype_maps[args.weight_dtype])

    pipeline = CatVTONPipeline(noise_scheduler, vae, unet)

    dataset = VITHONHD(
        args.data_record_path,
        args.height,
        args.width,
        is_train=False,
        extra_condition_key=args.extra_condition_key,
        data_nums=args.data_nums,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
    dataloader = accelerator.prepare(dataloader)

    progress_bar = tqdm(dataloader, total=len(dataloader), disable=not accelerator.is_main_process)

    output_dir = os.path.join(args.output_dir, args.model_root.split("/")[-1], f"cfg_{args.guidance_scale}")
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for batch in progress_bar:
            names = batch["name"]
            # NOTE: args.num_inference_steps is parsed but not forwarded here;
            # the pipeline falls back to its own default step count.
            sample = pipeline(
                image=batch["person"],
                condition_image=batch["cloth"],
                mask=batch["mask"],
                extra_condition=batch[args.extra_condition_key],
                guidance_scale=args.guidance_scale,
            )
            for idx, name in enumerate(names):
                save_path = os.path.join(output_dir, name.replace(".jpg", ".png"))
                person = Image.fromarray(batch["person_ori"][idx].cpu().numpy())
                mask = Image.fromarray(batch["mask_ori"][idx].cpu().numpy())
                result = sample[idx]
                if args.repaint:
                    result = repaint(person, mask, result)
                result.save(save_path)


if __name__ == "__main__":
    main()
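
# Example invocation (illustrative only; the script filename and the values in
# angle brackets are placeholders, not fixed by this repository):
#
#   accelerate launch inference.py \
#       --model_root <path/to/trained/model_root> \
#       --data_record_path <path/to/vitonhd/test/records> \
#       --output_dir <path/to/outputs> \
#       --extra_condition_key <key present in each dataset batch> \
#       --weight_dtype bf16 \
#       --guidance_scale 2.5 \
#       --repaint
#
# Results are written to <output_dir>/<last component of model_root>/cfg_<guidance_scale>/,
# one PNG per test sample.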