# main.py — training entry point: fine-tunes OOTDiffusion on the VITON-HD dataset.
import os
from argparse import ArgumentParser

import lightning as L
from lightning.pytorch.callbacks import Callback, ModelCheckpoint

from data import VITONDataModule
from model import OOTDiffusion


def get_args(argv=None):
    """Parse command-line options for OOTDiffusion training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            falls back to ``sys.argv[1:]``, preserving the original
            no-argument call behavior.

    Returns:
        argparse.Namespace with data, model-checkpoint, and training options.
    """
    parser = ArgumentParser()

    # Data
    parser.add_argument("--data_root", type=str,
                        default="/home/modelzoo/OOTDiffusion/datasets/VITON-HD")

    # Model components / checkpoint directories
    parser.add_argument("--vae_path", type=str,
                        default="/home/modelzoo/OOTDiffusion/checkpoints/ootd")
    parser.add_argument("--unet_path", type=str,
                        default="/home/modelzoo/OOTDiffusion/checkpoints/sd15")
    parser.add_argument("--model_path", type=str,
                        default="/home/modelzoo/OOTDiffusion/checkpoints/ootd")
    parser.add_argument("--vit_path", type=str,
                        default="/home/modelzoo/OOTDiffusion/checkpoints/clip-vit-large-patch14")
    parser.add_argument("--scheduler_path", type=str,
                        default="/home/modelzoo/OOTDiffusion/checkpoints/ootd/scheduler")
    # Model variant; "hd" by default. NOTE(review): presumably other values
    # (e.g. "dc") are accepted — confirm against OOTDiffusion in model.py.
    parser.add_argument("--mtype", type=str, default="hd")

    # Training hyperparameters
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--lr_scheduler", type=str, default="constant")

    return parser.parse_args(argv)


def main():
    """Train OOTDiffusion on VITON-HD with PyTorch Lightning."""
    args = get_args()

    dm = VITONDataModule(args.data_root)

    model = OOTDiffusion(args.vae_path,
                         args.unet_path,
                         args.model_path,
                         args.vit_path,
                         args.scheduler_path,
                         args.mtype,
                         args.batch_size,
                         args.max_length,
                         args.lr,
                         args.lr_scheduler)

    trainer = L.Trainer(
        max_epochs=50,
        accelerator='auto',
        log_every_n_steps=1,
        # Keep every periodic checkpoint (save_top_k=-1) plus a rolling "last".
        callbacks=[ModelCheckpoint(every_n_train_steps=6000,
                                   save_top_k=-1,
                                   save_last=True)],
        precision="16-mixed",
        # Effective batch size = batch_size * 32 via gradient accumulation.
        accumulate_grad_batches=32,
    )

    # Resume from the previous run's last checkpoint only if it exists;
    # otherwise start fresh instead of crashing on a hard-coded path that
    # is absent on a first run or a fresh clone.
    ckpt = "lightning_logs/version_6/checkpoints/last.ckpt"
    trainer.fit(model, dm, ckpt_path=ckpt if os.path.exists(ckpt) else None)


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()