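"""Test a detector with MMDetection: run inference on a test dataset with a
single GPU or with multiple GPUs (via torch.distributed) and optionally
evaluate the results with the COCO evaluation utilities."""
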
import argparse
import os.path as osp
import shutil
import tempfile

import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import load_checkpoint, get_dist_info
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmdet.apis import init_dist
from mmdet.core import results2json, coco_eval
from mmdet.datasets import build_dataloader, get_dataset
from mmdet.models import build_detector


def single_gpu_test(model, data_loader, show=False):
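    """Test the model with a single GPU and return the collected results.

    Inference runs in no-grad mode. When ``show`` is True the predictions are
    kept at the network input scale (``rescale=not show``) and visualized with
    ``model.module.show_result``; otherwise they are rescaled back to the
    original image size.
    """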
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=not show, **data)
        results.append(result)

        if show:
            model.module.show_result(data, result, dataset.img_norm_cfg)

        batch_size = data['img'][0].size(0)
        for _ in range(batch_size):
            prog_bar.update()
    return results


def multi_gpu_test(model, data_loader, tmpdir=None):
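    """Test the model with multiple GPUs under torch.distributed.

    Each rank runs inference on its own shard of the dataset; the partial
    results are then gathered on rank 0 by ``collect_results`` through a
    shared temporary directory (``tmpdir``).
    """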
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        results.append(result)

        if rank == 0:
            batch_size = data['img'][0].size(0)
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    results = collect_results(results, len(dataset), tmpdir)

    return results


def collect_results(result_part, size, tmpdir=None):
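    """Gather per-rank results on rank 0 through the filesystem.

    Every rank dumps its partial results to a pickle file in a shared
    temporary directory; after a barrier, rank 0 loads all parts, interleaves
    them back into dataset order and removes the directory. Returns the
    ordered results on rank 0 and ``None`` on all other ranks.
    """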
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
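        # rank 0 creates the tmp dir and shares its path with the other
        # ranks by encoding the path bytes into a fixed-size uint8 tensor
        # (padded with ASCII spaces) and broadcasting it to all processes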
        MAX_LEN = 512
        # 32 is the ASCII code of the space character, used as padding
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            tmpdir = tempfile.mkdtemp()
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
            part_list.append(mmcv.load(part_file))
        # re-order the results: the distributed sampler assigns samples to
        # ranks in a round-robin fashion, so interleaving the parts restores
        # the original dataset order
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def parse_args():
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
        help='eval types')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    return args


def main():
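    """Build the test dataset and model, run inference and, optionally, dump
    the raw results and/or run COCO-style evaluation.

    Example invocation (paths and eval types are placeholders):
        python test.py CONFIG_FILE CHECKPOINT_FILE --out results.pkl --eval bbox
    """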
    args = parse_args()

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = mmcv.Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = get_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        imgs_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # old versions did not save class info in checkpoints, this workaround
    # is for backward compatibility
    if 'CLASSES' in checkpoint['meta']:
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show)
    else:
        model = MMDistributedDataParallel(model.cuda())
        outputs = multi_gpu_test(model, data_loader, args.tmpdir)

    rank, _ = get_dist_info()
    if args.out and rank == 0:
        print('\nwriting results to {}'.format(args.out))
        mmcv.dump(outputs, args.out)
        eval_types = args.eval
        if eval_types:
            print('Starting evaluation of {}'.format(' and '.join(eval_types)))
            if eval_types == ['proposal_fast']:
                result_file = args.out
                coco_eval(result_file, eval_types, dataset.coco)
            else:
                if not isinstance(outputs[0], dict):
                    result_file = args.out + '.json'
                    results2json(dataset, outputs, result_file)
                    coco_eval(result_file, eval_types, dataset.coco)
                else:
                    for name in outputs[0]:
                        print('\nEvaluating {}'.format(name))
                        outputs_ = [out[name] for out in outputs]
                        result_file = args.out + '.{}.json'.format(name)
                        results2json(dataset, outputs_, result_file)
                        coco_eval(result_file, eval_types, dataset.coco)


if __name__ == '__main__':
    main()