# main.py — OOTDiffusion training entry point.
# (Removed non-Python residue from a web repository viewer: file-size line,
# author/commit attribution, and a column of bare line numbers.)
from data import VITONDataModule
from model import OOTDiffusion

from argparse import ArgumentParser
from lightning.pytorch.callbacks import Callback, ModelCheckpoint

import lightning as L


def get_args(argv=None):
    """Parse command-line options for OOTDiffusion training.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers are
            unaffected; passing a list makes the function testable and usable
            programmatically.

    Returns:
        argparse.Namespace with data, model, and training hyperparameters.
    """
    parser = ArgumentParser()

    # Data
    parser.add_argument("--data_root", type=str, default="/home/modelzoo/OOTDiffusion/datasets/VITON-HD")

    # Model checkpoints / components
    parser.add_argument("--vae_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/ootd")

    parser.add_argument("--unet_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/ootd/ootd_dc/checkpoint-36000")

    parser.add_argument("--model_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/ootd")

    parser.add_argument("--vit_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/clip-vit-large-patch14")

    parser.add_argument("--scheduler_path", type=str, default="/home/modelzoo/OOTDiffusion/checkpoints/ootd/scheduler")

    # "hd" (half-body) vs "dc" (dress-code) model variant — TODO confirm semantics with model.py
    parser.add_argument("--mtype", type=str, default="hd")

    # Training
    parser.add_argument("--batch_size", type=int, default=1)

    parser.add_argument("--max_length", type=int, default=128)

    parser.add_argument("--lr", type=float, default=5e-5)

    parser.add_argument("--lr_scheduler", type=str, default="constant")

    return parser.parse_args(argv)


def main():
    """Train OOTDiffusion on the VITON-HD dataset with PyTorch Lightning."""
    args = get_args()

    datamodule = VITONDataModule(args.data_root)

    model = OOTDiffusion(
        args.vae_path,
        args.unet_path,
        args.model_path,
        args.vit_path,
        args.scheduler_path,
        args.mtype,
        args.batch_size,
        args.max_length,
        args.lr,
        args.lr_scheduler,
    )

    # Checkpoint every 6000 training steps; keep every checkpoint
    # (save_top_k=-1) and always refresh a "last" checkpoint.
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=6000,
        save_top_k=-1,
        save_last=True,
    )

    trainer = L.Trainer(
        max_epochs=10,
        accelerator='auto',
        log_every_n_steps=1,
        callbacks=[checkpoint_callback],
        precision="16-mixed",
    )

    trainer.fit(model, datamodule)


# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()