"""Training entry point: fit an OOTDiffusion model on the VITON-HD data module."""

from argparse import ArgumentParser

import lightning as L
from lightning.pytorch.callbacks import Callback, ModelCheckpoint

from data import VITONDataModule
from model import OOTDiffusion


def get_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = ArgumentParser()

    # Data
    parser.add_argument("--data_root", type=str, default="/home/modelzoo/OOTDiffusion/datasets/VITON-HD")

    # Model
    parser.add_argument("--vae_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/ootd")
    parser.add_argument("--unet_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/sd15")
    parser.add_argument("--model_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/ootd")
    parser.add_argument("--vit_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/clip-vit-large-patch14")
    parser.add_argument("--scheduler_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/ootd/scheduler")
    parser.add_argument("--mtype", type=str, default="hd")

    # Training
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--lr_scheduler", type=str, default="constant")
    parser.add_argument("--max_epochs", type=int, default=50)
    parser.add_argument("--accumulate_grad_batches", type=int, default=32)
    # Previously hard-coded in the call to ``trainer.fit``. The default keeps
    # the old behavior; pass an empty string to train without a checkpoint.
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="lightning_logs/version_6/checkpoints/last.ckpt",
        help="Checkpoint to resume from; pass '' to start without one.",
    )

    return parser.parse_args(argv)


def main():
    """Build the data module, model and trainer from CLI options, then fit.

    A checkpoint is written every 6000 training steps (``save_top_k=-1`` and
    ``save_last=True`` are passed to ``ModelCheckpoint``), and training runs
    with ``"16-mixed"`` precision.
    """
    args = get_args()

    dm = VITONDataModule(args.data_root)
    # Arguments are passed positionally, in the same order as the original
    # script; the meaning of each is defined by ``OOTDiffusion`` in ``model``.
    model = OOTDiffusion(
        args.vae_path,
        args.unet_path,
        args.model_path,
        args.vit_path,
        args.scheduler_path,
        args.mtype,
        args.batch_size,
        args.max_length,
        args.lr,
        args.lr_scheduler,
    )

    checkpoint_cb = ModelCheckpoint(every_n_train_steps=6000, save_top_k=-1, save_last=True)
    trainer = L.Trainer(
        max_epochs=args.max_epochs,
        accelerator='auto',
        log_every_n_steps=1,
        callbacks=[checkpoint_cb],
        precision="16-mixed",
        accumulate_grad_batches=args.accumulate_grad_batches,
    )

    # An empty ``--ckpt_path`` becomes ``None`` so no checkpoint is loaded.
    trainer.fit(model, dm, ckpt_path=args.ckpt_path or None)


if __name__ == "__main__":
    main()