# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.runner import auto_fp16
from os import path as osp

from mmdet3d.core import show_seg_result
from mmseg.models.segmentors import BaseSegmentor


class Base3DSegmentor(BaseSegmentor):
    """Base class for 3D segmentors.

    The main difference with `BaseSegmentor` is that we modify the keys in
    data_dict and use a 3D seg specific visualization function.
    """

    @property
    def with_regularization_loss(self):
        """bool: whether the segmentor has regularization loss for weight"""
        return hasattr(self, 'loss_regularization') and \
            self.loss_regularization is not None

    def forward_test(self, points, img_metas, **kwargs):
        """Calls either simple_test or aug_test depending on the length of
        outer list of points. If len(points) == 1, call simple_test. Otherwise
        call aug_test to aggregate the test results by e.g. voting.

        Args:
            points (list[list[torch.Tensor]]): the outer list indicates
                test-time augmentations and inner torch.Tensor should have a
                shape BXNxC, which contains all points in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.

        Raises:
            TypeError: If ``points`` or ``img_metas`` is not a list.
            ValueError: If the two lists have mismatched lengths.
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(points)}) != '
                             f'num of image meta ({len(img_metas)})')

        if num_augs == 1:
            return self.simple_test(points[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(points, img_metas, **kwargs)

    # BUGFIX: `apply_to` must be a sequence of argument names. The previous
    # value `('points')` is just the string 'points' (parentheses without a
    # trailing comma do not make a tuple), so auto_fp16 iterated over its
    # characters instead of matching the `points` argument.
    @auto_fp16(apply_to=('points', ))
    def forward(self, return_loss=True, **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note this setting will change the expected inputs. When
        `return_loss=True`, point and img_metas are single-nested (i.e.
        torch.Tensor and list[dict]), and when `return_loss=False`, point and
        img_metas should be double nested (i.e. list[torch.Tensor],
        list[list[dict]]), with the outer list indicating test time
        augmentations.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def show_results(self,
                     data,
                     result,
                     palette=None,
                     out_dir=None,
                     ignore_index=None,
                     show=False,
                     score_thr=None):
        """Results visualization.

        Args:
            data (list[dict]): Input points and the information of the sample.
            result (list[dict]): Prediction results.
            palette (list[list[int]]] | np.ndarray): The palette of
                segmentation map. If None is given, random palette will be
                generated. Default: None
            out_dir (str): Output directory of visualization result.
            ignore_index (int, optional): The label index to be ignored, e.g.
                unannotated points. If None is given, set to len(self.CLASSES).
                Defaults to None.
            show (bool, optional): Determines whether you are
                going to show result by open3d.
                Defaults to False.
            TODO: implement score_thr of Base3DSegmentor.
            score_thr (float, optional): Score threshold of bounding boxes.
                Default to None.
                Not implemented yet, but it is here for unification.
        """
        assert out_dir is not None, 'Expect out_dir, got none.'
        if palette is None:
            if self.PALETTE is None:
                # No palette registered on the model: pick a random color
                # per class so the visualization is still usable.
                palette = np.random.randint(
                    0, 255, size=(len(self.CLASSES), 3))
            else:
                palette = self.PALETTE
        palette = np.array(palette)
        for batch_id in range(len(result)):
            if isinstance(data['points'][0], DC):
                points = data['points'][0]._data[0][batch_id].numpy()
            elif mmcv.is_list_of(data['points'][0], torch.Tensor):
                points = data['points'][0][batch_id]
            else:
                # BUGFIX: the exception was previously constructed but never
                # raised, letting execution continue and crash later with a
                # NameError on the undefined `points`.
                raise ValueError(
                    f"Unsupported data type {type(data['points'][0])} "
                    f'for visualization!')
            if isinstance(data['img_metas'][0], DC):
                pts_filename = data['img_metas'][0]._data[0][batch_id][
                    'pts_filename']
            elif mmcv.is_list_of(data['img_metas'][0], dict):
                pts_filename = data['img_metas'][0][batch_id]['pts_filename']
            else:
                # BUGFIX: likewise, raise instead of discarding the error.
                raise ValueError(
                    f"Unsupported data type {type(data['img_metas'][0])} "
                    f'for visualization!')
            # Use the point cloud's base filename (without extension) as the
            # output name for the saved visualization.
            file_name = osp.split(pts_filename)[-1].split('.')[0]

            pred_sem_mask = result[batch_id]['semantic_mask'].cpu().numpy()

            show_seg_result(
                points,
                None,  # no ground-truth segmentation available at test time
                pred_sem_mask,
                out_dir,
                file_name,
                palette,
                ignore_index,
                show=show)