test.py 4.18 KB
Newer Older
pangjm's avatar
pangjm committed
1
2
3
4
import argparse

import torch
import mmcv
Kai Chen's avatar
Kai Chen committed
5
from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
pangjm's avatar
pangjm committed
6
from mmcv.parallel import scatter, collate, MMDataParallel
Kai Chen's avatar
Kai Chen committed
7
8

from mmdet import datasets
Kai Chen's avatar
Kai Chen committed
9
from mmdet.core import results2json, coco_eval
pangjm's avatar
pangjm committed
10
from mmdet.datasets import build_dataloader
Kai Chen's avatar
Kai Chen committed
11
12
13
14
15
16
17
18
19
from mmdet.models import build_detector, detectors


def single_test(model, data_loader, show=False):
    """Run single-GPU inference over every batch yielded by ``data_loader``.

    Args:
        model: Detector wrapped in a data-parallel container. When ``show`` is
            True, ``model.module.show_result`` is called for every batch.
        data_loader: Iterable of batches. Its ``dataset`` attribute must
            support ``len()`` (used to size the progress bar) and expose
            ``img_norm_cfg`` (passed to ``show_result``).
        show (bool): If True, visualize each result via ``show_result``.
            NOTE: results are then computed with ``rescale=False``, so the
            returned results are not rescaled and should not be used for
            evaluation.

    Returns:
        list: One model output per batch, in loader order.
    """
    model.eval()
    results = []
    prog_bar = mmcv.ProgressBar(len(data_loader.dataset))
    for data in data_loader:
        with torch.no_grad():
            # Rescaling is disabled only when showing; presumably so the
            # drawn results match the resized input image — confirm against
            # show_result.
            result = model(return_loss=False, rescale=not show, **data)
        results.append(result)

        if show:
            model.module.show_result(data, result,
                                     data_loader.dataset.img_norm_cfg)

        # Advance the bar once per image rather than once per batch.
        # Assumes data['img'] is a sequence of batched tensors — TODO confirm.
        batch_size = data['img'][0].size(0)
        for _ in range(batch_size):
            prog_bar.update()
    return results


def _data_func(data, device_id):
    """Turn one dataset sample into forward kwargs for ``parallel_test``.

    Args:
        data: A single sample from the test dataset.
        device_id (int): Index of the GPU the sample is scattered to.

    Returns:
        dict: The collated, scattered sample fields plus
        ``return_loss=False`` and ``rescale=True`` for test-mode forward.
    """
    # Collate the lone sample into a batch of one, then scatter it to the
    # single target device; scatter yields one entry per device, so take [0].
    data = scatter(collate([data], samples_per_gpu=1), [device_id])[0]
    return dict(return_loss=False, rescale=True, **data)


def parse_args():
    """Parse the command-line arguments of the test script.

    Returns:
        argparse.Namespace: Attributes ``config``, ``checkpoint``, ``gpus``,
        ``proc_per_gpu``, ``out``, ``eval`` (list of str or None) and
        ``show``.

    Exits via ``parser.error`` (SystemExit) when ``--gpus`` or
    ``--proc_per_gpu`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--gpus', default=1, type=int, help='GPU number used for testing')
    parser.add_argument(
        '--proc_per_gpu',
        default=1,
        type=int,
        help='Number of processes per GPU')
    parser.add_argument('--out', help='output result file')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
        help='eval types')
    parser.add_argument('--show', action='store_true', help='show results')
    args = parser.parse_args()
    # Non-positive counts would make the multi-GPU path run with an empty
    # GPU range or zero workers per GPU; reject them as a usage error.
    if args.gpus < 1:
        parser.error('--gpus must be a positive integer')
    if args.proc_per_gpu < 1:
        parser.error('--proc_per_gpu must be a positive integer')
    return args


def _evaluate(outputs, dataset, out_file, eval_types):
    """Run COCO-style evaluation on test outputs that were dumped to a pickle.

    Args:
        outputs (list): Per-sample results. Each item is either a raw result
            or a dict mapping a result name to a raw result.
        dataset: Test dataset; its ``coco`` attribute is passed to
            ``coco_eval``.
        out_file (str): Path the outputs were dumped to. It is used directly
            for ``proposal_fast`` and as the prefix of generated ``.json``
            files otherwise.
        eval_types (list[str]): Evaluation types selected with ``--eval``.
    """
    print('Starting evaluate {}'.format(' and '.join(eval_types)))
    if eval_types == ['proposal_fast']:
        # proposal_fast is evaluated from the dumped pickle itself; no json
        # conversion is performed.
        coco_eval(out_file, eval_types, dataset.coco)
        return
    if not outputs:
        # Avoid an IndexError on outputs[0] below for an empty dataset.
        print('No results to evaluate.')
        return
    if isinstance(outputs[0], dict):
        # Dict outputs hold several named results (presumably one per
        # head/branch); evaluate each name separately with its own json file.
        for name in outputs[0]:
            print('\nEvaluating {}'.format(name))
            outputs_ = [out[name] for out in outputs]
            result_file = out_file + '.{}.json'.format(name)
            results2json(dataset, outputs_, result_file)
            coco_eval(result_file, eval_types, dataset.coco)
    else:
        result_file = out_file + '.json'
        results2json(dataset, outputs, result_file)
        coco_eval(result_file, eval_types, dataset.coco)


def main():
    """Build the test dataset and model, run inference, then dump/evaluate.

    Raises:
        ValueError: If ``--out`` is not a ``.pkl``/``.pickle`` path, or if
            ``--eval`` is requested without ``--out``.
    """
    args = parse_args()

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')
    # Evaluation only runs after results are dumped to --out. Fail before the
    # (possibly long) inference instead of silently skipping evaluation.
    if args.eval and not args.out:
        raise ValueError('--eval requires --out to be specified.')
    if args.show and args.gpus > 1:
        # The multi-GPU branch below never uses args.show.
        import warnings
        warnings.warn('--show is only supported with --gpus 1; ignoring it.')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    dataset = obj_from_dict(cfg.data.test, datasets, dict(test_mode=True))
    if args.gpus == 1:
        model = build_detector(
            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
        load_checkpoint(model, args.checkpoint)
        model = MMDataParallel(model, device_ids=[0])

        data_loader = build_dataloader(
            dataset,
            imgs_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            num_gpus=1,
            dist=False,
            shuffle=False)
        outputs = single_test(model, data_loader, args.show)
    else:
        # parallel_test receives the model class, its kwargs and the
        # checkpoint path rather than a built model (presumably so each
        # worker constructs its own copy — confirm in mmcv).
        model_args = cfg.model.copy()
        model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
        model_type = getattr(detectors, model_args.pop('type'))
        outputs = parallel_test(
            model_type,
            model_args,
            args.checkpoint,
            dataset,
            _data_func,
            range(args.gpus),
            workers_per_gpu=args.proc_per_gpu)

    if args.out:
        print('writing results to {}'.format(args.out))
        mmcv.dump(outputs, args.out)
        if args.eval:
            _evaluate(outputs, dataset, args.out, args.eval)


if __name__ == '__main__':
    # Run the test pipeline only when executed as a script, not on import.
    main()