singan_inference.py 3.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys

import mmcv
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint

# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../../..')))  # isort:skip  # noqa

from mmgen.apis import set_random_seed  # isort:skip  # noqa
from mmgen.models import build_model  # isort:skip  # noqa

# yapf: enable


def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate a GAN model')
    parser.add_argument('config', help='evaluation config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--samples-path',
        type=str,
        default=None,
        help='path to store images. If not given, remove it after evaluation\
             finished')
    parser.add_argument(
        '--save-prev-res',
        action='store_true',
        help='whether to store the results from previous stages')
    parser.add_argument(
        '--num-samples',
        type=int,
        default=10,
        help='the number of synthesized samples')
    args = parser.parse_args()
    return args


def _tensor2img(img):
    img = img[0].permute(1, 2, 0)
    img = ((img + 1) / 2 * 255).to(torch.uint8)

    return img.cpu().numpy()


@torch.no_grad()
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the model and load checkpoint
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    model.eval()

    # load ckpt
    mmcv.print_log(f'Loading ckpt from {args.checkpoint}', 'mmgen')
    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')

    # add dp wrapper
    model = MMDataParallel(model, device_ids=[0])

    pbar = mmcv.ProgressBar(args.num_samples)
    for sample_iter in range(args.num_samples):
        outputs = model(None, num_batches=1, get_prev_res=args.save_prev_res)

        # store results from previous stages
        if args.save_prev_res:
            fake_img = outputs['fake_img']
            prev_res_list = outputs['prev_res_list']
            prev_res_list.append(fake_img)
            for i, img in enumerate(prev_res_list):
                img = _tensor2img(img)
                mmcv.imwrite(
                    img,
                    os.path.join(args.samples_path, f'stage{i}',
                                 f'rand_sample_{sample_iter}.png'))
        # just store the final result
        else:
            img = _tensor2img(outputs)
            mmcv.imwrite(
                img,
                os.path.join(args.samples_path,
                             f'rand_sample_{sample_iter}.png'))

        pbar.update()

    # change the line after pbar
    sys.stdout.write('\n')


if __name__ == '__main__':
    main()