import argparse
import os
import os.path as osp
import shutil
import tempfile

import mmcv
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, load_checkpoint

from mmdet.apis import init_dist
from mmdet.core import coco_eval, results2json, wrap_fp16_model
from mmdet.datasets import build_dataloader, build_dataset
from mmdet.models import build_detector


def single_gpu_test(model, data_loader, show=False):
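    """Test the model on a single GPU and collect per-image results.

    Inference runs with gradients disabled; when ``show`` is True the
    predictions are visualized via ``show_result`` and results are kept at
    the test scale (``rescale=False``).
    """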
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=not show, **data)
        results.append(result)

        if show:
            model.module.show_result(data, result, dataset.img_norm_cfg)

        batch_size = data['img'][0].size(0)
        for _ in range(batch_size):
            prog_bar.update()
    return results


def multi_gpu_test(model, data_loader, tmpdir=None):
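    """Test the model with multiple GPUs and gather results on rank 0.

    Each rank runs inference on its shard of the dataset; the partial
    results are merged through a shared temporary directory by
    ``collect_results``.
    """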
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        results.append(result)

        if rank == 0:
            batch_size = data['img'][0].size(0)
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    results = collect_results(results, len(dataset), tmpdir)

    return results


def collect_results(result_part, size, tmpdir=None):
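    """Collect per-rank partial results into one ordered list on rank 0.

    If ``tmpdir`` is not given, rank 0 creates one and broadcasts its path
    to the other ranks. Every rank dumps its part to that directory; rank 0
    then loads all parts, restores dataset order, drops samples padded by
    the dataloader, removes the directory and returns the merged list.
    Other ranks return ``None``.
    """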
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # pad with ASCII 32 (space) so the decoded path can be rstrip'ed below
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            tmpdir = tempfile.mkdtemp()
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
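    # wait until every rank has finished writing before rank 0 reads the parts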
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
            part_list.append(mmcv.load(part_file))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def parse_args():
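    """Parse command-line arguments for the test script."""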
    parser = argparse.ArgumentParser(description='MMDet test detector')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file')
    parser.add_argument(
        '--json_out',
        help='output result file name without extension',
        type=str)
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
        help='eval types')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
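    """Build the dataset and model, run testing, then save/evaluate results."""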
    args = parse_args()

    assert args.out or args.show or args.json_out, \
        ('Please specify at least one operation (save or show the results) '
         'with the argument "--out" or "--show" or "--json_out"')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    if args.json_out is not None and args.json_out.endswith('.json'):
        args.json_out = args.json_out[:-5]

    cfg = mmcv.Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        imgs_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
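    # optionally wrap the model for fp16 (half-precision) inference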
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint['meta']:
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show)
    else:
        model = MMDistributedDataParallel(model.cuda())
        outputs = multi_gpu_test(model, data_loader, args.tmpdir)

    rank, _ = get_dist_info()
    if args.out and rank == 0:
        print('\nwriting results to {}'.format(args.out))
        mmcv.dump(outputs, args.out)
        eval_types = args.eval
        if eval_types:
            print('Starting to evaluate {}'.format(' and '.join(eval_types)))
            if eval_types == ['proposal_fast']:
                result_file = args.out
                coco_eval(result_file, eval_types, dataset.coco)
            else:
                if not isinstance(outputs[0], dict):
                    result_files = results2json(dataset, outputs, args.out)
                    coco_eval(result_files, eval_types, dataset.coco)
                else:
                    for name in outputs[0]:
                        print('\nEvaluating {}'.format(name))
                        outputs_ = [out[name] for out in outputs]
                        result_file = args.out + '.{}'.format(name)
                        result_files = results2json(dataset, outputs_,
                                                    result_file)
                        coco_eval(result_files, eval_types, dataset.coco)

    # Save predictions in the COCO json format
    if args.json_out and rank == 0:
        if not isinstance(outputs[0], dict):
            results2json(dataset, outputs, args.json_out)
        else:
            for name in outputs[0]:
                outputs_ = [out[name] for out in outputs]
                result_file = args.json_out + '.{}'.format(name)
                results2json(dataset, outputs_, result_file)


if __name__ == '__main__':
    main()