test.py 10.1 KB
Newer Older
pangjm's avatar
pangjm committed
1
import argparse
lizz's avatar
lizz committed
2
import os
3
import os.path as osp
4
import pickle
5
6
import shutil
import tempfile
pangjm's avatar
pangjm committed
7
8

import mmcv
9
10
11
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
12
from mmcv.runner import get_dist_info, load_checkpoint
Kai Chen's avatar
Kai Chen committed
13

14
from mmdet.apis import init_dist
15
from mmdet.core import coco_eval, results2json, wrap_fp16_model
16
from mmdet.datasets import build_dataloader, build_dataset
17
from mmdet.models import build_detector
Kai Chen's avatar
Kai Chen committed
18
19


20
def single_gpu_test(model, data_loader, show=False):
    """Run inference over a whole data loader on a single GPU.

    Args:
        model (nn.Module): Detector called with ``return_loss``/``rescale``
            keyword arguments plus the batch contents; when ``show`` is True
            it must also expose ``model.module.show_result``.
        data_loader: Iterable of batches. Each batch is a dict forwarded to
            the model as keyword arguments; dim 0 of ``batch['img'][0]`` is
            used as the number of images in the batch.
        show (bool): If True, visualize every result and run the model with
            ``rescale=False``.

    Returns:
        list: One model output per batch, in loader order.
    """
    model.eval()
    progress = mmcv.ProgressBar(len(data_loader.dataset))
    collected = []
    for batch in data_loader:
        with torch.no_grad():
            output = model(return_loss=False, rescale=not show, **batch)
        collected.append(output)

        if show:
            model.module.show_result(batch, output)

        # the bar counts images, so advance it once per image in the batch
        num_images = batch['img'][0].size(0)
        for _ in range(num_images):
            progress.update()
    return collected


def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Run distributed inference and gather the results on rank 0.

    Every process runs the model on the batches its own ``data_loader``
    yields. The per-rank outputs are then merged either through GPU
    collectives (``gpu_collect=True``) or by writing them to a directory
    that rank 0 reads back (``gpu_collect=False``).

    Args:
        model (nn.Module): Model to be tested.
        data_loader (torch.utils.data.DataLoader): Data loader for this
            rank (this script builds it with ``dist=True``).
        tmpdir (str | None): Directory for the per-rank result files in
            cpu mode. A temporary directory is created when None.
        gpu_collect (bool): Gather results via GPU communication instead of
            the filesystem.

    Returns:
        list | None: The merged predictions (cut to ``len(dataset)``) on
        rank 0; None on every other rank.
    """
    model.eval()
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    is_master = rank == 0
    progress = mmcv.ProgressBar(len(dataset)) if is_master else None
    local_outputs = []
    for batch in data_loader:
        with torch.no_grad():
            local_outputs.append(
                model(return_loss=False, rescale=True, **batch))
        if is_master:
            # only rank 0 draws the bar; it assumes every rank handled a
            # batch of the same size and advances for all of them at once
            seen = batch['img'][0].size(0) * world_size
            for _ in range(seen):
                progress.update()

    # merge the per-rank outputs on rank 0
    total = len(dataset)
    if gpu_collect:
        return collect_results_gpu(local_outputs, total)
    return collect_results_cpu(local_outputs, total, tmpdir)


def collect_results_cpu(result_part, size, tmpdir=None):
    """Gather per-rank results on rank 0 through a shared directory.

    Every rank pickles ``result_part`` into ``<tmpdir>/part_<rank>.pkl``.
    After a barrier, rank 0 loads all parts, interleaves them and drops the
    padding the data loader may have added.

    Args:
        result_part (list): Results produced by the current rank.
        size (int): Dataset length; the merged list is truncated to it.
        tmpdir (str | None): Directory that every rank can reach, since
            rank 0 reads the files the other ranks write. When None, rank 0
            creates one with ``tempfile.mkdtemp`` and broadcasts its path;
            that directory is removed afterwards. When given, it is created
            if missing and only the part files written here are removed, so
            unrelated contents of a user-supplied directory survive.

    Returns:
        list | None: Merged results on rank 0, None on all other ranks.
    """
    rank, world_size = get_dist_info()
    # only a directory this function created may be deleted wholesale
    owns_tmpdir = tmpdir is None
    if owns_tmpdir:
        MAX_LEN = 512
        # 32 is whitespace: the path is right-padded with it for the
        # fixed-size broadcast and stripped again afterwards.
        # NOTE(review): an encoded path longer than MAX_LEN fails on rank 0
        # only, leaving the other ranks waiting in the broadcast.
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            tmpdir = tempfile.mkdtemp()
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    part_files = [
        osp.join(tmpdir, 'part_{}.pkl'.format(i)) for i in range(world_size)
    ]
    # dump the part result to the dir
    mmcv.dump(result_part, part_files[rank])
    # every part must be on disk before rank 0 starts reading
    dist.barrier()
    if rank != 0:
        return None
    # load results of all parts from tmp dir
    part_list = [mmcv.load(part_file) for part_file in part_files]
    # presumably the sampler deals samples round-robin across ranks, so
    # interleaving the parts restores dataset order
    ordered_results = []
    for res in zip(*part_list):
        ordered_results.extend(list(res))
    # the dataloader may pad some samples
    ordered_results = ordered_results[:size]
    if owns_tmpdir:
        shutil.rmtree(tmpdir)
    else:
        # never rmtree a caller-provided directory: it may hold other files
        for part_file in part_files:
            os.remove(part_file)
    return ordered_results


def collect_results_gpu(result_part, size):
    """Gather per-rank results on rank 0 using GPU collectives.

    Each rank pickles ``result_part`` into a uint8 CUDA tensor. The payload
    lengths are gathered first, and every payload is zero-padded to the
    longest one so all ranks exchange equally sized tensors.

    Args:
        result_part (list): Results produced by the current rank.
        size (int): Dataset length; the merged list is truncated to it.

    Returns:
        list | None: Merged results on rank 0, None on all other ranks.
    """
    rank, world_size = get_dist_info()
    payload = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # exchange payload lengths so every rank knows how far to pad
    local_len = torch.tensor(payload.shape, device='cuda')
    all_lens = [local_len.clone() for _ in range(world_size)]
    dist.all_gather(all_lens, local_len)
    padded_len = torch.tensor(all_lens).max()
    send_buf = torch.zeros(padded_len, dtype=torch.uint8, device='cuda')
    send_buf[:local_len[0]] = payload
    recv_bufs = [payload.new_zeros(padded_len) for _ in range(world_size)]
    dist.all_gather(recv_bufs, send_buf)

    if rank != 0:
        return None

    # cut the padding off each received buffer before unpickling it
    part_list = [
        pickle.loads(buf[:length[0]].cpu().numpy().tobytes())
        for buf, length in zip(recv_bufs, all_lens)
    ]
    merged = []
    for res in zip(*part_list):
        merged.extend(list(res))
    # the dataloader may pad some samples
    return merged[:size]


def parse_args():
    """Parse the command line of the test script.

    When ``LOCAL_RANK`` is not already in the environment, the parsed
    ``--local_rank`` value is exported under that name.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='MMDet test detector')
    add = parser.add_argument
    add('config', help='test config file path')
    add('checkpoint', help='checkpoint file')
    add('--out', help='output result file')
    add('--json_out',
        help='output result file name without extension',
        type=str)
    add('--eval',
        type=str,
        nargs='+',
        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
        help='eval types')
    add('--show', action='store_true', help='show results')
    add('--gpu_collect',
        action='store_true',
        help='whether to use gpu to collect results')
    add('--tmpdir', help='tmp dir for writing some results')
    add('--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    add('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # an already-exported LOCAL_RANK wins over the command-line value
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
    return args


def _evaluate(dataset, outputs, result_file, eval_types):
    """Run COCO evaluation of ``outputs`` for the requested ``eval_types``.

    ``result_file`` is the path ``outputs`` were dumped to. It is passed to
    ``coco_eval`` directly for ``['proposal_fast']``; otherwise it is the
    prefix handed to ``results2json``. When each output is a dict, every
    entry is converted and evaluated separately under the prefix
    ``result_file + '.<name>'``.
    """
    print('Starting evaluate {}'.format(' and '.join(eval_types)))
    if eval_types == ['proposal_fast']:
        coco_eval(result_file, eval_types, dataset.coco)
        return
    if not isinstance(outputs[0], dict):
        result_files = results2json(dataset, outputs, result_file)
        coco_eval(result_files, eval_types, dataset.coco)
        return
    for name in outputs[0]:
        print('\nEvaluating {}'.format(name))
        outputs_ = [out[name] for out in outputs]
        result_files = results2json(dataset, outputs_,
                                    result_file + '.{}'.format(name))
        coco_eval(result_files, eval_types, dataset.coco)


def _dump_json(dataset, outputs, json_prefix):
    """Save ``outputs`` in the COCO json format under ``json_prefix``.

    Dict outputs are split per key and each part is written under
    ``json_prefix + '.<name>'``.
    """
    if not isinstance(outputs[0], dict):
        results2json(dataset, outputs, json_prefix)
        return
    for name in outputs[0]:
        outputs_ = [out[name] for out in outputs]
        results2json(dataset, outputs_, json_prefix + '.{}'.format(name))


def main():
    """Test a detector checkpoint on the test split of a config.

    Parses the command line, builds the test dataset, loader and model,
    runs single- or multi-GPU inference depending on ``--launcher``, and on
    rank 0 dumps the raw outputs (``--out``), optionally evaluates them
    (``--eval``) and/or writes COCO json files (``--json_out``).

    Raises:
        ValueError: If no output operation is requested or ``--out`` does
            not end in ``.pkl``/``.pickle``.
    """
    import warnings

    args = parse_args()

    # raise rather than assert so the check survives `python -O`
    if not (args.out or args.show or args.json_out):
        raise ValueError(
            'Please specify at least one operation (save or show the '
            'results) with the argument "--out" or "--show" or '
            '"--json_out"')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    if args.eval and not args.out:
        # evaluation uses the --out path as its result file / prefix, so it
        # is skipped without it; say so instead of ignoring --eval silently
        warnings.warn('--eval is ignored because "--out" is not specified')

    if args.json_out is not None and args.json_out.endswith('.json'):
        args.json_out = args.json_out[:-5]

    cfg = mmcv.Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        imgs_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # old versions did not save class info in checkpoints, and a checkpoint
    # may have no (or an empty) 'meta' entry at all; fall back to the
    # dataset classes in both cases instead of raising KeyError
    meta = checkpoint.get('meta') or {}
    if 'CLASSES' in meta:
        model.CLASSES = meta['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show)
    else:
        model = MMDistributedDataParallel(model.cuda())
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    rank, _ = get_dist_info()
    # only rank 0 holds the merged outputs; the other ranks are done
    if rank != 0:
        return

    if args.out:
        print('\nwriting results to {}'.format(args.out))
        mmcv.dump(outputs, args.out)
        if args.eval:
            _evaluate(dataset, outputs, args.out, args.eval)

    # Save predictions in the COCO json format
    if args.json_out:
        _dump_json(dataset, outputs, args.json_out)


# Script entry point: run the test pipeline only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()