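"""Test (inference) script for MMDetection detectors.

Runs a trained detector over a test dataset on one or more GPUs, optionally
visualizes the predictions, dumps the raw results to a pickle file, and
evaluates them with the metrics selected via --eval.

Example invocation (the config and checkpoint paths are illustrative):

    python test.py configs/faster_rcnn_r50_fpn_1x.py work_dirs/latest.pth \
        --gpus 2 --out results.pkl --eval bbox

"""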
import argparse

import torch
import mmcv
from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
from mmcv.parallel import scatter, collate, MMDataParallel

from mmdet import datasets
from mmdet.core import results2json, coco_eval
from mmdet.datasets import build_dataloader
from mmdet.models import build_detector, detectors


def single_test(model, data_loader, show=False):
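    """Run single-GPU inference over the whole dataset.

    Returns a list with one result entry per image; when ``show`` is set,
    each prediction is also visualized via the model's ``show_result``.
    """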
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
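            # rescale boxes to the original image scale unless the results
            # are about to be drawn on the resized network input (show=True)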
            result = model(return_loss=False, rescale=not show, **data)
        results.append(result)

        if show:
            model.module.show_result(data, result, dataset.img_norm_cfg,
                                     dataset.CLASSES)

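        # advance the progress bar once per image in the batch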
        batch_size = data['img'][0].size(0)
        for _ in range(batch_size):
            prog_bar.update()
    return results


def _data_func(data, device_id):
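    """Collate a single sample and scatter it to the given GPU.

    Used as the per-sample data function for ``parallel_test`` workers.
    """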
    data = scatter(collate([data], samples_per_gpu=1), [device_id])[0]
    return dict(return_loss=False, rescale=True, **data)


def parse_args():
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--gpus',
        default=1,
        type=int,
        help='number of GPUs to use for testing')
    parser.add_argument(
        '--proc_per_gpu',
        default=1,
        type=int,
        help='Number of processes per GPU')
    parser.add_argument('--out', help='output result file (a .pkl file)')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
        help='eval types')
    parser.add_argument('--show', action='store_true', help='show results')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = mmcv.Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
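    # all weights come from the checkpoint, so skip loading any pretrained
    # backbone; test_mode makes the dataset return images without annotations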
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

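    # instantiate the test dataset described in the config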
    dataset = obj_from_dict(cfg.data.test, datasets, dict(test_mode=True))
    if args.gpus == 1:
        model = build_detector(
            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
        load_checkpoint(model, args.checkpoint)
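        # MMDataParallel handles scattering DataContainers to the GPU even
        # when only a single device is used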
        model = MMDataParallel(model, device_ids=[0])

        data_loader = build_dataloader(
            dataset,
            imgs_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            num_gpus=1,
            dist=False,
            shuffle=False)
        outputs = single_test(model, data_loader, args.show)
    else:
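        # multi-GPU: parallel_test rebuilds the model and loads the
        # checkpoint inside each worker process, so only the detector class
        # and its constructor kwargs are passed here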
        model_args = cfg.model.copy()
        model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
        model_type = getattr(detectors, model_args.pop('type'))
        outputs = parallel_test(
            model_type,
            model_args,
            args.checkpoint,
            dataset,
            _data_func,
            range(args.gpus),
            workers_per_gpu=args.proc_per_gpu)

    if args.out:
        print('Writing results to {}'.format(args.out))
        mmcv.dump(outputs, args.out)
        eval_types = args.eval
        if eval_types:
            print('Starting to evaluate {}'.format(' and '.join(eval_types)))
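            # 'proposal_fast' is evaluated directly on the pickled results;
            # the other metrics are first converted to COCO-style json files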
            if eval_types == ['proposal_fast']:
                result_file = args.out
                coco_eval(result_file, eval_types, dataset.coco)
            else:
                if not isinstance(outputs[0], dict):
                    result_file = args.out + '.json'
                    results2json(dataset, outputs, result_file)
                    coco_eval(result_file, eval_types, dataset.coco)
                else:
                    for name in outputs[0]:
                        print('\nEvaluating {}'.format(name))
                        outputs_ = [out[name] for out in outputs]
                        result_file = args.out + '.{}.json'.format(name)
                        results2json(dataset, outputs_, result_file)
                        coco_eval(result_file, eval_types, dataset.coco)


if __name__ == '__main__':
    main()