test.py 9.6 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangwenwei's avatar
zhangwenwei committed
2
import argparse
zhangwenwei's avatar
zhangwenwei committed
3
import os
Wenhao Wu's avatar
Wenhao Wu committed
4
import warnings
5
6
7

import mmcv
import torch
zhangwenwei's avatar
zhangwenwei committed
8
from mmcv import Config, DictAction
Wenhao Wu's avatar
Wenhao Wu committed
9
from mmcv.cnn import fuse_conv_bn
zhangwenwei's avatar
zhangwenwei committed
10
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
Wenhao Wu's avatar
Wenhao Wu committed
11
12
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)
zhangwenwei's avatar
zhangwenwei committed
13

VVsssssk's avatar
VVsssssk committed
14
import mmdet
liyinhao's avatar
liyinhao committed
15
from mmdet3d.apis import single_gpu_test
zhangwenwei's avatar
zhangwenwei committed
16
from mmdet3d.datasets import build_dataloader, build_dataset
17
from mmdet3d.models import build_model
liyinhao's avatar
liyinhao committed
18
from mmdet.apis import multi_gpu_test, set_random_seed
Wenhao Wu's avatar
Wenhao Wu committed
19
from mmdet.datasets import replace_ImageToTensor
zhangwenwei's avatar
zhangwenwei committed
20

VVsssssk's avatar
VVsssssk committed
21
22
# Prefer the implementations shipped with newer mmdet (> 2.23.0) and fall back
# to the mmdet3d copies otherwise. Feature detection via ImportError is used
# instead of comparing ``mmdet.__version__`` as a plain string, because string
# comparison is lexicographic (e.g. '2.3.0' > '2.23.0' evaluates to True).
try:
    # If mmdet version > 2.23.0, setup_multi_processes would be imported and
    # used from mmdet instead of mmdet3d.
    from mmdet.utils import setup_multi_processes
except ImportError:
    from mmdet3d.utils import setup_multi_processes

try:
    # If mmdet version > 2.23.0, compat_cfg would be imported and
    # used from mmdet instead of mmdet3d.
    from mmdet.utils import compat_cfg
except ImportError:
    from mmdet3d.utils import compat_cfg

zhangwenwei's avatar
zhangwenwei committed
35
36
37
38
39
40
41
42

def parse_args():
    """Parse command-line arguments for testing (and evaluating) a model.

    Side effects:
        Exports ``--local_rank`` as the ``LOCAL_RANK`` environment variable
        when it is not already set.
        Maps the deprecated ``--options`` onto ``--eval-options`` (with a
        warning).

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If both ``--options`` and ``--eval-options`` are given.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        # Trailing space needed: adjacent literals are concatenated verbatim.
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where results will be saved')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecated), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    # Launchers that pass --local_rank instead of setting LOCAL_RANK still
    # need the environment variable to be populated.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both specified, '
            '--options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options

    return args


def _check_args(args):
    """Validate combinations of command-line options before any heavy work.

    Uses explicit exceptions rather than ``assert`` so the checks still run
    when Python is started with ``-O``.

    Raises:
        ValueError: If no output operation is requested, if ``--eval`` and
            ``--format-only`` are combined, or if ``--out`` does not end in
            ``.pkl``/``.pickle``.
    """
    if not (args.out or args.eval or args.format_only or args.show
            or args.show_dir):
        raise ValueError(
            'Please specify at least one operation (save/eval/format/show the '
            'results / save the results) with the argument "--out", "--eval"'
            ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')


def _prepare_test_data_cfg(cfg):
    """Put the test dataset config(s) into test mode, in place.

    ``cfg.data.test`` may be a single dataset config (dict) or a list of
    configs (concatenated dataset). When the test dataloader batches more
    than one sample per GPU, 'ImageToTensor' in the pipelines is replaced
    via ``replace_ImageToTensor``.

    Returns:
        The ``test_dataloader`` overrides from ``cfg.data`` ({} if absent).
    """
    loader_overrides = cfg.data.get('test_dataloader', {})
    batched = loader_overrides.get('samples_per_gpu', 1) > 1

    test_cfg = cfg.data.test
    if isinstance(test_cfg, dict):
        test_cfg.test_mode = True
        if batched:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            test_cfg.pipeline = replace_ImageToTensor(test_cfg.pipeline)
    elif isinstance(test_cfg, list):
        for ds_cfg in test_cfg:
            ds_cfg.test_mode = True
            if batched:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
    return loader_overrides


def main():
    """Test (and optionally evaluate) a model from the command line.

    Loads the config and checkpoint given on the command line, builds the
    test dataset, dataloader and model, and runs single- or multi-GPU
    inference. On rank 0 it then optionally dumps the raw results, formats
    them and/or evaluates them.
    """
    args = parse_args()
    _check_args(args)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    cfg = compat_cfg(cfg)

    # set multi-process settings
    setup_multi_processes(cfg)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    # Config-provided dataloader settings override these defaults.
    test_loader_cfg = {
        **dict(
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=distributed,
            shuffle=False),
        **_prepare_test_data_cfg(cfg)
    }

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset, **test_loader_cfg)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
    if cfg.get('fp16', None) is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    meta = checkpoint.get('meta', {})
    # old versions did not save class info in checkpoints, this walkaround is
    # for backward compatibility
    if 'CLASSES' in meta:
        model.CLASSES = meta['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES
    # palette for visualization in segmentation tasks
    if 'PALETTE' in meta:
        model.PALETTE = meta['PALETTE']
    elif hasattr(dataset, 'PALETTE'):
        # segmentation dataset has `PALETTE` attribute
        model.PALETTE = dataset.PALETTE

    if not distributed:
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    # Only rank 0 writes, formats and evaluates the collected results.
    rank, _ = get_dist_info()
    if rank != 0:
        return

    if args.out:
        print(f'\nwriting results to {args.out}')
        mmcv.dump(outputs, args.out)
    kwargs = {} if args.eval_options is None else args.eval_options
    if args.format_only:
        dataset.format_results(outputs, **kwargs)
    if args.eval:
        eval_kwargs = cfg.get('evaluation', {}).copy()
        # hard-code way to remove EvalHook args
        for key in [
                'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                'rule'
        ]:
            eval_kwargs.pop(key, None)
        eval_kwargs.update(dict(metric=args.eval, **kwargs))
        print(dataset.evaluate(outputs, **eval_kwargs))


if __name__ == '__main__':
    main()