# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F

from mmdet3d.utils import ConfigType


class EdgeFusionModule(BaseModule):
    """Edge Fusion Module for feature map.

    Features sampled at image-edge locations are refined by a small 1D
    convolution stack and added back onto the corresponding positions of
    the fused feature map.

    Args:
        out_channels (int): The number of output channels.
        feat_channels (int): The number of channels in feature map
            during edge feature fusion.
        kernel_size (int): Kernel size of convolution. Defaults to 3.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to dict(type='ReLU').
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN1d').
    """

    def __init__(
        self,
        out_channels: int,
        feat_channels: int,
        kernel_size: int = 3,
        act_cfg: ConfigType = dict(type='ReLU'),
        norm_cfg: ConfigType = dict(type='BN1d')
    ) -> None:
        super(EdgeFusionModule, self).__init__()
        # Same-padding 1D conv (with norm + activation) followed by a
        # plain 1x1 projection down to ``out_channels``.
        same_pad = kernel_size // 2
        refine_conv = ConvModule(
            feat_channels,
            feat_channels,
            kernel_size=kernel_size,
            padding=same_pad,
            conv_cfg=dict(type='Conv1d'),
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        project_conv = nn.Conv1d(feat_channels, out_channels, kernel_size=1)
        self.edge_convs = nn.Sequential(refine_conv, project_conv)
        self.feat_channels = feat_channels

    def forward(self, features: Tensor, fused_features: Tensor,
                edge_indices: Tensor, edge_lens: List[int], output_h: int,
                output_w: int) -> Tensor:
        """Forward pass.

        Args:
            features (Tensor): Different representative features for fusion.
            fused_features (Tensor): Different representative features
                to be fused (updated in place and also returned).
            edge_indices (Tensor): Batch image edge indices.
            edge_lens (List[int]): List of edge length of each image.
            output_h (int): Height of output feature map.
            output_w (int): Width of output feature map.

        Returns:
            Tensor: Fused feature maps.
        """
        batch_size = features.shape[0]

        # Map pixel coordinates into the [-1, 1] range expected by
        # ``F.grid_sample`` (coord 0 is x/width, coord 1 is y/height).
        sample_grid = edge_indices.view(batch_size, -1, 1, 2).float()
        for coord_axis, axis_size in ((0, output_w), (1, output_h)):
            sample_grid[..., coord_axis] = \
                sample_grid[..., coord_axis] / (axis_size - 1) * 2 - 1

        # Gather per-edge-point features and refine them with the 1D
        # convolution stack.
        edge_feats = F.grid_sample(
            features, sample_grid, align_corners=True).squeeze(-1)
        edge_out = self.edge_convs(edge_feats)

        # Add each sample's valid edge responses back onto the fused map
        # (only the first ``edge_lens[k]`` points of sample ``k`` are valid).
        for sample_idx in range(batch_size):
            num_valid = edge_lens[sample_idx]
            valid_indices = edge_indices[sample_idx, :num_valid]
            cols = valid_indices[:, 0]
            rows = valid_indices[:, 1]
            fused_features[sample_idx, :, rows, cols] += \
                edge_out[sample_idx, :, :num_valid]

        return fused_features