mvx_faster_rcnn.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from ..builder import DETECTORS
from .mvx_two_stage import MVXTwoStageDetector


@DETECTORS.register_module()
class MVXFasterRCNN(MVXTwoStageDetector):
    """Multi-modality VoxelNet using Faster R-CNN."""

    def __init__(self, **kwargs):
        super(MVXFasterRCNN, self).__init__(**kwargs)


@DETECTORS.register_module()
class DynamicMVXFasterRCNN(MVXTwoStageDetector):
    """Multi-modality VoxelNet using Faster R-CNN and dynamic voxelization."""

    def __init__(self, **kwargs):
        super(DynamicMVXFasterRCNN, self).__init__(**kwargs)

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points):
        """Apply dynamic voxelization to points.

        Args:
            points (list[torch.Tensor]): Points of each sample.

        Returns:
            tuple[torch.Tensor]: Concatenated points and coordinates.
        """
        coors = []
        # dynamic voxelization only provides a point-to-voxel coordinate mapping
        for res in points:
            res_coors = self.pts_voxel_layer(res)
            coors.append(res_coors)
        points = torch.cat(points, dim=0)
        coors_batch = []
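        # prepend the sample index to each coordinate so that voxels from
        # different samples remain distinguishable after concatenation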
        for i, coor in enumerate(coors):
            coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
            coors_batch.append(coor_pad)
        coors_batch = torch.cat(coors_batch, dim=0)
        return points, coors_batch

    def extract_pts_feat(self, points, img_feats, img_metas):
        """Extract point features."""
        if not self.with_pts_bbox:
            return None
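        # with dynamic voxelization, `voxels` is simply the concatenated raw
        # points and `coors` holds the corresponding voxel coordinates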
        voxels, coors = self.voxelize(points)
        voxel_features, feature_coors = self.pts_voxel_encoder(
            voxels, coors, points, img_feats, img_metas)
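        # the first column of `coors` is the sample index, so the last entry
        # gives the number of samples in the batch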
        batch_size = coors[-1, 0] + 1
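        # the middle encoder turns the sparse voxel features into dense
        # feature maps that the 2D backbone and neck can consume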
        x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size)
        x = self.pts_backbone(x)
        if self.with_pts_neck:
            x = self.pts_neck(x)
        return x