test.py 5.47 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
2
import argparse
import mmcv
zhangwenwei's avatar
zhangwenwei committed
3
import os
zhangwenwei's avatar
zhangwenwei committed
4
import torch
zhangwenwei's avatar
zhangwenwei committed
5
from mmcv import Config, DictAction
zhangwenwei's avatar
zhangwenwei committed
6
7
8
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

liyinhao's avatar
liyinhao committed
9
from mmdet3d.apis import single_gpu_test
zhangwenwei's avatar
zhangwenwei committed
10
11
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_detector
liyinhao's avatar
liyinhao committed
12
from mmdet.apis import multi_gpu_test, set_random_seed
zhangwenwei's avatar
zhangwenwei committed
13
from mmdet.core import wrap_fp16_model
zhangwenwei's avatar
zhangwenwei committed
14
from tools.fuse_conv_bn import fuse_module
zhangwenwei's avatar
zhangwenwei committed
15
16
17
18
19
20
21
22
23


def parse_args():
    """Parse the command-line arguments of the test script.

    As a side effect, copies ``--local_rank`` into the ``LOCAL_RANK``
    environment variable when that variable is not already set.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        # Trailing space matters: the two literals are concatenated.
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It '
        'is useful when you want to format the result to a specific format '
        'and submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    # Only forwarded to ``single_gpu_test`` (non-distributed runs).
    parser.add_argument(
        '--show-dir', help='directory where results will be saved')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu_collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    # key=value pairs; forwarded as keyword arguments to the dataset's
    # ``format_results`` / ``evaluate`` calls in ``main``.
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='custom options')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # Do not overwrite a LOCAL_RANK that a launcher has already exported.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def _validate_args(args):
    """Check that the parsed CLI arguments describe a meaningful run.

    Args:
        args (argparse.Namespace): Arguments returned by ``parse_args``.

    Raises:
        ValueError: If no operation (``--out``/``--eval``/``--format-only``/
            ``--show``) is requested, if ``--eval`` and ``--format-only`` are
            combined, or if ``--out`` is not a ``.pkl``/``.pickle`` path.
    """
    # Raise instead of ``assert`` so the check is not stripped by
    # ``python -O``, and so every argument error shares one exception type.
    if not (args.out or args.eval or args.format_only or args.show):
        raise ValueError(
            'Please specify at least one operation (save/eval/format/show the '
            'results) with the argument "--out", "--eval", "--format-only" '
            'or "--show"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format-only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')


def main():
    """Run inference with a detector, then save, format or evaluate results.

    Steps: validate arguments, load the config, optionally initialise the
    distributed environment, set random seeds, build the test dataset,
    dataloader and model, load the checkpoint, run single- or multi-GPU
    inference, and finally (on rank 0 only) dump, format and/or evaluate
    the outputs.
    """
    args = parse_args()
    _validate_args(args)

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # All weights come from ``args.checkpoint`` below, so presumably there is
    # no need to load pretrained weights while building the model.
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the dataloader
    # ``samples_per_gpu`` is consumed by the dataloader, so it is popped off
    # the dataset config before that config is handed to ``build_dataset``.
    samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_module(model)
    # Old checkpoints may not store class names -- or may lack the ``meta``
    # entry altogether -- so fall back to the dataset's classes instead of
    # failing with a KeyError.
    meta = checkpoint.get('meta') or {}
    if 'CLASSES' in meta:
        model.CLASSES = meta['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        # ``--show-dir`` is only forwarded on this single-GPU path.
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    # Only rank 0 writes, formats and evaluates, so the work is not repeated
    # in every process.
    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            print(f'\nwriting results to {args.out}')
            mmcv.dump(outputs, args.out)
        # ``--options`` key/value pairs are forwarded as keyword arguments.
        kwargs = {} if args.options is None else args.options
        if args.format_only:
            dataset.format_results(outputs, **kwargs)
        if args.eval:
            dataset.evaluate(outputs, args.eval, **kwargs)


if __name__ == '__main__':
    main()