# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmengine.registry import MODELS
from torch import Tensor
from torch import nn as nn

from mmdet3d.utils import ConfigType, OptMultiConfig


@MODELS.register_module()
class GroupFree3DMHA(MultiheadAttention):
    """A wrapper for torch.nn.MultiheadAttention for GroupFree3D.

    This module implements MultiheadAttention with an identity connection,
    and the positional encoding used in DETR is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Defaults to 0.0.
        proj_drop (float): A Dropout layer. Defaults to 0.0.
        dropout_layer (ConfigType): The dropout_layer used when adding
            the shortcut. Defaults to dict(type='Dropout', drop_prob=0.).
        init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Defaults to None.
        batch_first (bool): Whether Key, Query and Value are of shape
            (batch, n, embed_dim) or (n, batch, embed_dim).
            Defaults to False.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 dropout_layer: ConfigType = dict(
                     type='Dropout', drop_prob=0.),
                 init_cfg: OptMultiConfig = None,
                 batch_first: bool = False,
                 **kwargs) -> None:
        super(GroupFree3DMHA,
              self).__init__(embed_dims, num_heads, attn_drop, proj_drop,
                             dropout_layer, init_cfg, batch_first, **kwargs)

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                identity: Tensor,
                query_pos: Optional[Tensor] = None,
                key_pos: Optional[Tensor] = None,
                attn_mask: Optional[Tensor] = None,
                key_padding_mask: Optional[Tensor] = None,
                **kwargs) -> Tensor:
        """Forward function for `GroupFree3DMHA`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
                If None, the ``query`` will be used.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as x,
                will be used for the identity link. If None, `x` will be used.
            query_pos (Tensor, optional): The positional encoding for query,
                with the same shape as `x`. Defaults to None.
                If not None, it will be added to `x` before forward function.
            key_pos (Tensor, optional): The positional encoding for `key`,
                with the same shape as `key`. Defaults to None. If not None,
                it will be added to `key` before forward function. If None,
                and `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`.
            attn_mask (Tensor, optional): ByteTensor mask with shape
                [num_queries, num_keys].
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor, optional): ByteTensor with shape
                [bs, num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.

        Returns:
            Tensor: Forwarded results with shape [num_queries, bs, embed_dims].
        """

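        # GroupFree3D adds the positional encoding to ``value`` before
        # attention: ``query_pos`` for self-attention, ``key_pos`` for
        # cross-attention, and ``query_pos`` when no ``operation_name``
        # has been set on this module.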
        if hasattr(self, 'operation_name'):
            if self.operation_name == 'self_attn':
                value = value + query_pos
            elif self.operation_name == 'cross_attn':
                value = value + key_pos
            else:
                raise NotImplementedError(
                    f'{self.__class__.__name__} '
                    f"can't be used as {self.operation_name}")
        else:
            value = value + query_pos

        return super(GroupFree3DMHA, self).forward(
            query=query,
            key=key,
            value=value,
            identity=identity,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
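

# A minimal, illustrative usage sketch (the embed/head/query sizes below are
# arbitrary example values, and ``dropout_layer`` is set explicitly); tensor
# shapes follow the default ``batch_first=False`` layout of
# ``[num_queries, bs, embed_dims]``:
#
#     import torch
#     mha = GroupFree3DMHA(
#         embed_dims=288,
#         num_heads=8,
#         dropout_layer=dict(type='Dropout', drop_prob=0.1))
#     query = torch.rand(256, 2, 288)
#     query_pos = torch.rand(256, 2, 288)
#     out = mha(
#         query=query,
#         key=query,
#         value=query,
#         identity=query,
#         query_pos=query_pos,
#         key_pos=query_pos)  # out has shape (256, 2, 288)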


@MODELS.register_module()
class ConvBNPositionalEncoding(nn.Module):
    """Absolute position embedding with Conv learning.

    Args:
        input_channel (int): Input features dim.
        num_pos_feats (int): Output position features dim.
            Defaults to 288 to be consistent with seed features dim.
    """

    def __init__(self, input_channel: int, num_pos_feats: int = 288) -> None:
        super(ConvBNPositionalEncoding, self).__init__()
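        # A small point-wise MLP: two 1x1 Conv1d layers with BatchNorm and
        # ReLU in between, mapping (B, input_channel, N) coordinates to
        # (B, num_pos_feats, N) positional features.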
        self.position_embedding_head = nn.Sequential(
            nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
            nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True),
            nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))

    def forward(self, xyz: Tensor) -> Tensor:
        """Forward pass.

        Args:
            xyz (Tensor): (B, N, 3) The coordinates to embed.

        Returns:
            Tensor: (B, num_pos_feats, N) The embedded position features.
        """
        xyz = xyz.permute(0, 2, 1)
        position_embedding = self.position_embedding_head(xyz)
        return position_embedding
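

# A minimal, illustrative usage sketch (batch size, number of points and the
# 288-dim output below are arbitrary example values):
#
#     import torch
#     pos_enc = ConvBNPositionalEncoding(input_channel=3, num_pos_feats=288)
#     xyz = torch.rand(2, 1024, 3)   # (B, N, 3) seed coordinates
#     feats = pos_enc(xyz)           # (B, 288, N) position features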