"""Test a trained detector: run inference, optionally dump and evaluate results."""
import argparse

import torch
import mmcv
from mmcv.torchpack import load_checkpoint, parallel_test, obj_from_dict

from mmdet import datasets
from mmdet.core import scatter, MMDataParallel, results2json, coco_eval
from mmdet.datasets.loader import collate, build_dataloader
from mmdet.models import build_detector, detectors


def single_test(model, data_loader, show=False):
    """Run ``model`` over every batch of ``data_loader`` on a single device.

    Args:
        model: Wrapped detector. When ``show`` is set the code reaches
            through ``model.module``, so a ``MMDataParallel``-style wrapper
            is expected.
        data_loader: Iterable of batches. Its ``dataset`` attribute must
            support ``len()`` and expose ``img_norm_cfg``.
        show (bool): If True, draw every result via
            ``model.module.show_result``. Results are then produced with
            ``rescale=False`` (the model is called with ``rescale=not show``).

    Returns:
        list: One model output per batch, in loader order.
    """
    model.eval()
    dataset = data_loader.dataset
    progress = mmcv.ProgressBar(len(dataset))
    outputs = []
    for batch in data_loader:
        with torch.no_grad():
            output = model(**batch, return_loss=False, rescale=not show)
        outputs.append(output)

        if show:
            model.module.show_result(batch, output, dataset.img_norm_cfg)

        # The bar is sized in images, so advance it once per image in the
        # batch rather than once per batch.
        num_imgs = batch['img'][0].size(0)
        for _ in range(num_imgs):
            progress.update()
    return outputs


def _data_func(data, device_id):
    """Turn one raw dataset sample into model keyword arguments.

    Used as the per-sample hook for ``parallel_test``. The sample is
    collated into a batch of one and scattered onto ``device_id``. The
    resulting kwargs are then extended with the test-time flags
    ``return_loss=False`` and ``rescale=True``.
    """
    batched = collate([data], samples_per_gpu=1)
    on_device = scatter(batched, [device_id])[0]
    return dict(**on_device, return_loss=False, rescale=True)


def parse_args(argv=None):
    """Parse command-line options for the detector test script.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is what the
            command-line entry point uses.

    Returns:
        argparse.Namespace: With attributes ``config``, ``checkpoint``,
        ``gpus``, ``out``, ``eval`` and ``show``.

    Exits through ``parser.error`` (``SystemExit``) in two cases:

    - ``--gpus`` is smaller than 1.
    - ``--eval`` is given without ``--out``. Evaluation reads the dumped
      results, so this combination would otherwise run the full inference
      and then silently skip evaluation.
    """
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--gpus', default=1, type=int, help='number of gpus to use')
    parser.add_argument('--out', help='output result file')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        choices=['proposal', 'bbox', 'segm', 'keypoints'],
        help='eval types')
    parser.add_argument('--show', action='store_true', help='show results')
    args = parser.parse_args(argv)

    # Fail fast, before any (potentially hours-long) inference starts.
    if args.gpus < 1:
        parser.error('--gpus must be a positive integer')
    if args.eval and not args.out:
        parser.error('--eval requires --out to be specified')
    return args


def main():
    """Build the test dataset and detector, run inference, then dump/evaluate.

    Command-line options are parsed here rather than at import time, so
    importing this module (e.g. to reuse ``single_test``) does not consume
    or validate ``sys.argv``.

    Flow:

    - With ``--gpus 1``: build the detector, load ``args.checkpoint``, wrap
      it in ``MMDataParallel`` on device 0, and run ``single_test`` over a
      non-shuffled loader with one image per GPU.
    - Otherwise: hand the model class and its arguments to
      ``parallel_test`` across devices ``0 .. gpus-1``.
    - If ``--out`` is given: dump the outputs there. If ``--eval`` is also
      given: convert them to ``<out>.json`` with ``results2json`` and
      evaluate with ``coco_eval``.
    """
    import warnings

    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    # Weights are restored from args.checkpoint below, so the config's
    # `pretrained` entry is cleared.
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    dataset = obj_from_dict(cfg.data.test, datasets, dict(test_mode=True))
    if args.gpus == 1:
        model = build_detector(
            cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
        load_checkpoint(model, args.checkpoint)
        model = MMDataParallel(model, device_ids=[0])

        data_loader = build_dataloader(
            dataset,
            imgs_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            num_gpus=1,
            dist=False,
            shuffle=False)
        # NOTE(review): single_test calls the model with rescale=not show, so
        # with --show the dumped/evaluated results are not rescaled. Confirm
        # whether results2json/coco_eval expect rescaled boxes.
        outputs = single_test(model, data_loader, args.show)
    else:
        if args.show:
            # The multi-GPU path has no visualisation hook; say so instead
            # of silently dropping the flag.
            warnings.warn('--show is only supported with --gpus 1; '
                          'ignoring it')
        model_args = cfg.model.copy()
        model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
        model_type = getattr(detectors, model_args.pop('type'))
        outputs = parallel_test(model_type, model_args, args.checkpoint,
                                dataset, _data_func, range(args.gpus))

    if args.out:
        mmcv.dump(outputs, args.out)
        if args.eval:
            json_file = args.out + '.json'
            results2json(dataset, outputs, json_file)
            coco_eval(json_file, args.eval, dataset.coco)


if __name__ == '__main__':
    main()