test.py 2.95 KB
Newer Older
pangjm's avatar
pangjm committed
1
2
3
4
import argparse

import torch
import mmcv
Kai Chen's avatar
Kai Chen committed
5
from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
Kai Chen's avatar
Kai Chen committed
6
7

from mmdet import datasets
8
from mmdet.core import scatter, MMDataParallel, results2json, coco_eval
Kai Chen's avatar
Kai Chen committed
9
from mmdet.datasets import collate, build_dataloader
Kai Chen's avatar
Kai Chen committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from mmdet.models import build_detector, detectors


def single_test(model, data_loader, show=False):
    """Run inference with ``model`` over every batch of ``data_loader``.

    Args:
        model: wrapped detector (``model.module`` must expose
            ``show_result`` when ``show`` is True).
        data_loader: loader over the test dataset; its ``dataset`` is used
            for the progress-bar length and, when showing, ``img_norm_cfg``.
        show: if True, visualise each result. Rescaling back to the
            original image size is disabled in that case.

    Returns:
        list: one model output per batch, in loader order.
    """
    model.eval()
    dataset = data_loader.dataset
    progress = mmcv.ProgressBar(len(dataset))
    collected = []
    for batch in data_loader:
        with torch.no_grad():
            output = model(**batch, return_loss=False, rescale=not show)
        collected.append(output)

        if show:
            model.module.show_result(batch, output, dataset.img_norm_cfg)

        # Advance the bar once per image. The count is read from the first
        # entry of batch['img'] (presumably the first test scale).
        num_imgs = batch['img'][0].size(0)
        for _ in range(num_imgs):
            progress.update()
    return collected


def _data_func(data, device_id):
    """Prepare one test sample for GPU ``device_id``.

    This is the per-sample data hook passed to ``parallel_test``. The sample
    is collated into a batch of one, then scattered onto the device. The
    result is returned as keyword arguments for an inference-mode forward
    call: no loss, with rescaling to the original image size.
    """
    single_batch = collate([data], samples_per_gpu=1)
    device_inputs = scatter(single_batch, [device_id])[0]
    return dict(**device_inputs, return_loss=False, rescale=True)
pangjm's avatar
pangjm committed
35
36
37
38
39
40


def parse_args():
    """Parse the command-line arguments of the test script.

    Returns:
        argparse.Namespace with ``config``, ``checkpoint``, ``gpus`` (int),
        ``out`` (str or None), ``eval`` (list of str or None) and ``show``
        (bool).

    Exits via ``parser.error`` (usage message, status 2) when ``--gpus`` is
    not a positive integer.
    """
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--gpus', default=1, type=int, help='number of GPUs to test with')
    parser.add_argument('--out', help='output result file')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
        help='eval types')
    parser.add_argument('--show', action='store_true', help='show results')
    args = parser.parse_args()
    # The multi-GPU path uses range(args.gpus) as the device list, so 0 or
    # a negative count would hand it an empty list instead of failing.
    if args.gpus < 1:
        parser.error(
            '--gpus must be a positive integer, got {}'.format(args.gpus))
    return args


def main():
    """Run detector inference on the configured test set and save/evaluate it.

    The test dataset is built from ``cfg.data.test``.
    - With ``--gpus 1`` the model is built, loaded from the checkpoint and
      run through ``single_test``.
    - Otherwise the detector class and its constructor arguments are handed
      to ``parallel_test`` together with the checkpoint path and the device
      ids ``range(args.gpus)``.

    When ``--out`` is given, the raw outputs are dumped there. If ``--eval``
    is also given, they are additionally converted to ``<out>.json`` and
    evaluated against ``dataset.coco``.
    """
    import warnings

    args = parse_args()

    # Surface option combinations that would otherwise be silently ignored,
    # before spending time on a full inference pass.
    if args.eval and not args.out:
        warnings.warn('--eval has no effect without --out: results are only '
                      'evaluated after being dumped to the --out file')
    if args.show and args.gpus != 1:
        warnings.warn('--show is only supported with --gpus 1 and will be '
                      'ignored')

    cfg = mmcv.Config.fromfile(args.config)
    # The trained weights come from args.checkpoint, so skip any
    # pretrained-weight initialisation declared in the config.
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    dataset = obj_from_dict(cfg.data.test, datasets, dict(test_mode=True))
    if args.gpus == 1:
        model = build_detector(
            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
        load_checkpoint(model, args.checkpoint)
        model = MMDataParallel(model, device_ids=[0])

        data_loader = build_dataloader(
            dataset,
            imgs_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            num_gpus=1,
            dist=False,
            shuffle=False)
        outputs = single_test(model, data_loader, args.show)
    else:
        # parallel_test receives the detector class plus constructor kwargs
        # and the checkpoint path rather than a built model (presumably so
        # it can instantiate one model per device).
        model_args = cfg.model.copy()
        model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
        model_type = getattr(detectors, model_args.pop('type'))
        outputs = parallel_test(model_type, model_args, args.checkpoint,
                                dataset, _data_func, range(args.gpus))

    if args.out:
        mmcv.dump(outputs, args.out)
        if args.eval:
            json_file = args.out + '.json'
            results2json(dataset, outputs, json_file)
            coco_eval(json_file, args.eval, dataset.coco)
pangjm's avatar
pangjm committed
89
90
91
92


# Script entry point: ``python test.py CONFIG CHECKPOINT [--gpus N] [--out FILE]
# [--eval TYPE ...] [--show]``.
if __name__ == '__main__':
    main()