test.py 3.19 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from os import path as osp

liyinhao's avatar
liyinhao committed
4
5
import mmcv
import torch
6
from mmcv.image import tensor2imgs
7

8
9
from mmdet3d.models import (Base3DDetector, Base3DSegmentor,
                            SingleStageMono3DDetector)
liyinhao's avatar
liyinhao committed
10
11


12
13
14
15
16
def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    """Test model with single gpu.

    This method tests model with single gpu and gives the 'show' option.
    By setting ``show=True``, it saves the visualization results under
    ``out_dir``.

    Args:
        model (nn.Module): Model to be tested. When ``show`` is True the
            underlying network is accessed as ``model.module``, so the model
            is expected to be wrapped (presumably by ``MMDataParallel``).
        data_loader (torch.utils.data.DataLoader): PyTorch data loader.
        show (bool, optional): Whether to save visualization results.
            Default: False.
        out_dir (str, optional): The path to save visualization results.
            Default: None.
        show_score_thr (float, optional): Score threshold forwarded to the
            visualization API as ``score_thr``. Default: 0.3.

    Returns:
        list[dict]: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    # Models exposing MMDetection3D's ``show_results`` API. Any other model
    # falls back to MMDetection's per-image ``show_result`` API below.
    # Built once here because it does not change between batches.
    models_3d = (Base3DDetector, Base3DSegmentor, SingleStageMono3DDetector)
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        batch_size = len(result)

        if show:
            if isinstance(model.module, models_3d):
                # Visualize the results of an MMDetection3D model.
                model.module.show_results(
                    data,
                    result,
                    out_dir=out_dir,
                    show=show,
                    score_thr=show_score_thr)
            else:
                # Visualize the results of an MMDetection model; its
                # 'show_result' API works on one de-normalized image at a
                # time.
                if batch_size == 1 and isinstance(data['img'][0],
                                                  torch.Tensor):
                    img_tensor = data['img'][0]
                else:
                    img_tensor = data['img'][0].data[0]
                img_metas = data['img_metas'][0].data[0]
                imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
                assert len(imgs) == len(img_metas)

                for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                    # Keep only the ``img_shape`` region of the tensor
                    # (presumably dropping padding added by the pipeline).
                    h, w, _ = img_meta['img_shape']
                    img_show = img[:h, :w, :]

                    # Resize back to the original image resolution.
                    ori_h, ori_w = img_meta['ori_shape'][:-1]
                    img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                    if out_dir:
                        out_file = osp.join(out_dir, img_meta['ori_filename'])
                    else:
                        out_file = None

                    model.module.show_result(
                        img_show,
                        result[i],
                        show=show,
                        out_file=out_file,
                        score_thr=show_score_thr)
        results.extend(result)

        # Advance the progress bar once per sample in this batch.
        for _ in range(batch_size):
            prog_bar.update()
    return results