# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmengine.registry import MODELS
from torch import nn as nn


@MODELS.register_module()
class GroupFree3DMHA(MultiheadAttention):
    """A warpper for torch.nn.MultiheadAttention for GroupFree3D.

    This module implements MultiheadAttention with identity connection,
    and positional encoding used in DETR is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        attn_drop (float, optional): A Dropout layer on attn_output_weights.
            Defaults to 0.0.
        proj_drop (float, optional): A Dropout layer. Defaults to 0.0.
        dropout_layer (obj:`ConfigDict`, optional): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`, optional): The Config for
            initialization. Defaults to None.
        batch_first (bool, optional): Whether Key, Query and Value have
            shape (batch, n, embed_dim) or (n, batch, embed_dim).
            Defaults to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(embed_dims, num_heads, attn_drop, proj_drop,
                         dropout_layer, init_cfg, batch_first, **kwargs)

    def forward(self,
                query,
                key,
                value,
                identity,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `GroupFree3DMHA`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims]. Same in `nn.MultiheadAttention.forward`.
                If None, the ``query`` will be used.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as `query`,
                will be used for the identity link. If None, `query` will
                be used.
            query_pos (Tensor, optional): The positional encoding for query,
                with the same shape as `query`. Defaults to None. If not
                None, it will be added to `query` before the forward
                function.
            key_pos (Tensor, optional): The positional encoding for `key`,
                with the same shape as `key`. If not None, it will be added
                to `key` before the forward function. If None and `query_pos`
                has the same shape as `key`, `query_pos` will be used for
                `key_pos`. Defaults to None.
            attn_mask (Tensor, optional): ByteTensor mask with shape
                [num_queries, num_keys].
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor, optional): ByteTensor with shape
                [bs, num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """

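        # ``operation_name`` is typically set on this module by the enclosing
        # transformer layer; it selects which positional encoding is added to
        # ``value`` before the attention call (``query_pos`` for
        # self-attention, ``key_pos`` for cross-attention).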
        if hasattr(self, 'operation_name'):
            if self.operation_name == 'self_attn':
                value = value + query_pos
            elif self.operation_name == 'cross_attn':
                value = value + key_pos
            else:
                raise NotImplementedError(
                    f'{self.__class__.__name__} '
                    f"can't be used as {self.operation_name}")
        else:
            value = value + query_pos

        return super(GroupFree3DMHA, self).forward(
            query=query,
            key=key,
            value=value,
            identity=identity,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
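

# Illustrative sketch, not part of the upstream module: one assumed way to
# drive ``GroupFree3DMHA`` directly as decoder self-attention. The helper
# name and all sizes below are hypothetical; tensors use the
# (n, batch, embed_dim) layout because ``batch_first`` defaults to False.
def _example_group_free_3d_mha():
    """Minimal usage sketch for ``GroupFree3DMHA`` (illustrative only)."""
    import torch

    embed_dims, num_heads, num_queries, batch_size = 288, 8, 256, 2
    attn = GroupFree3DMHA(
        embed_dims=embed_dims,
        num_heads=num_heads,
        # Pass the dropout layer explicitly under the name mmcv registers
        # for ``nn.Dropout``.
        dropout_layer=dict(type='Dropout', drop_prob=0.))
    # Normally assigned by the enclosing transformer layer; it selects the
    # positional encoding that ``forward`` adds to ``value``.
    attn.operation_name = 'self_attn'

    query = torch.rand(num_queries, batch_size, embed_dims)
    query_pos = torch.rand(num_queries, batch_size, embed_dims)
    # For self-attention, key, value and the identity shortcut all come
    # from ``query``.
    out = attn(
        query=query,
        key=query,
        value=query,
        identity=query,
        query_pos=query_pos)
    assert out.shape == (num_queries, batch_size, embed_dims)
    return out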


@MODELS.register_module()
class ConvBNPositionalEncoding(nn.Module):
    """Absolute position embedding with Conv learning.

    Args:
        input_channel (int): Input feature dimension.
        num_pos_feats (int, optional): Output positional feature dimension.
            Defaults to 288 to be consistent with the seed feature dimension.
    """

    def __init__(self, input_channel, num_pos_feats=288):
        super().__init__()
        self.position_embedding_head = nn.Sequential(
            nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
            nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True),
            nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))

    def forward(self, xyz):
        """Forward pass.

        Args:
            xyz (Tensor): (B, N, 3) the coordinates to embed.

        Returns:
            Tensor: (B, num_pos_feats, N) the embedded position features.
        """
        xyz = xyz.permute(0, 2, 1)
        position_embedding = self.position_embedding_head(xyz)
        return position_embedding
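

# Illustrative sketch, not part of the upstream module: embedding a batch of
# seed coordinates with ``ConvBNPositionalEncoding``. The helper name and the
# sizes below are hypothetical; ``num_pos_feats=288`` simply matches the seed
# feature dimension mentioned in the docstring above.
def _example_conv_bn_positional_encoding():
    """Minimal usage sketch for ``ConvBNPositionalEncoding`` (illustrative)."""
    import torch

    batch_size, num_seeds = 2, 1024
    pos_embedding = ConvBNPositionalEncoding(
        input_channel=3, num_pos_feats=288)
    xyz = torch.rand(batch_size, num_seeds, 3)  # (B, N, 3) seed coordinates
    pos_feats = pos_embedding(xyz)  # (B, num_pos_feats, N)
    assert pos_feats.shape == (batch_size, 288, num_seeds)
    return pos_feats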