Main_Test.py
from Diffusion.Train import train, eval, eval_tmp
import os
import argparse

import numpy as np
import torch


def main(model_config=None):
    """Dispatch to training or evaluation according to modelConfig['state']."""
    if model_config is None:
        raise ValueError("main() expects a model_config dict")
    modelConfig = model_config
    if modelConfig["state"] == "train":
        train(modelConfig)
        # After training, reload the final checkpoint and sample 32 image grids.
        modelConfig["batch_size"] = 64
        modelConfig["test_load_weight"] = "ckpt_{}_.pt".format(modelConfig["epoch"])
        for i in range(32):
            modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
            eval(modelConfig)
    else:
        for i in range(1):
            modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
            eval_tmp(modelConfig, 1000)  # grid visualization
            # eval(modelConfig)  # metric evaluation

def seed_all(args):
    """Seed Python/NumPy/PyTorch RNGs for reproducible runs."""
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)
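    # Note: cudnn.deterministic=True together with benchmark=False trades some speed
    # for reproducible convolution results across runs.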

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--state', type=str, default='eval') # train or eval
    parser.add_argument('--dataset', type=str, default='cvc') # busi, glas, cvc
    parser.add_argument('--epoch', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--T', type=int, default=1000)
    parser.add_argument('--channel', type=int, default=64)
    parser.add_argument('--test_load_weight', type=str, default='ckpt_1000_.pt')
    parser.add_argument('--num_res_blocks', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.15)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--img_size', type=int, default=64) # 64 or 128
    parser.add_argument('--dataset_repeat', type=int, default=1) # not used
    parser.add_argument('--seed', type=int, default=0) 
    parser.add_argument('--model', type=str, default='UKan_Hybrid')
    parser.add_argument('--exp_nme', type=str, default='./')

    parser.add_argument('--save_root', type=str, default='released_models/ukan_cvc') 
    # parser.add_argument('--save_root', type=str, default='released_models/ukan_glas') 
    # parser.add_argument('--save_root', type=str, default='released_models/ukan_busi') 
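
    # Example invocations (datasets and checkpoints are assumed to be in place under save_root):
    #   python Main_Test.py --state train --dataset cvc --save_root released_models/ukan_cvc
    #   python Main_Test.py --state eval --dataset glas --test_load_weight ckpt_1000_.pt --save_root released_models/ukan_glas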
    args = parser.parse_args()

    save_root = args.save_root
    if args.seed != 0: # a seed of 0 leaves the RNGs unseeded
        seed_all(args)

    modelConfig = {
        "dataset": args.dataset, 
        "state": args.state, # or eval
        "epoch": args.epoch,
        "batch_size": args.batch_size,
        "T": args.T,
        "channel": args.channel,
        "channel_mult": [1, 2, 3, 4],
        "attn": [2],
        "num_res_blocks": args.num_res_blocks,
        "dropout": args.dropout,
        "lr": args.lr,
        "multiplier": 2.,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 64,
        "grad_clip": 1.,
        "device": "cuda", ### MAKE SURE YOU HAVE A GPU !!!
        "training_load_weight": None,
        "save_weight_dir": os.path.join(save_root, args.exp_nme, "Weights"),
        "sampled_dir": os.path.join(save_root, args.exp_nme, "FinalCheck"),
        "test_load_weight": args.test_load_weight,
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
        "model":args.model, 
        "version": 1,
        "dataset_repeat": args.dataset_repeat,
        "seed": args.seed,
        "save_root": args.save_root,
        }

    os.makedirs(modelConfig["save_weight_dir"], exist_ok=True)
    os.makedirs(modelConfig["sampled_dir"], exist_ok=True)

    main(modelConfig)