transformer.py 8.16 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
# modify from https://github.com/TuSimple/centerformer/blob/master/det3d/models/utils/transformer.py # noqa

import torch
from einops import rearrange
from mmcv.cnn.bricks.activation import GELU
from torch import einsum, nn

from .multi_scale_deform_attn import MSDeformAttn


class PreNorm(nn.Module):
    """Layer-normalize the input(s) of a wrapped callable before invoking it.

    Args:
        dim (int): Size of the last axis normalized by ``nn.LayerNorm``.
        fn (callable): Module/function called with the normalized input(s)
            plus any extra keyword arguments.

    Note:
        A single ``LayerNorm`` (one set of affine weights) is applied to both
        ``x`` and, when given, ``y``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, y=None, **kwargs):
        """Call ``fn(norm(x), [norm(y),] **kwargs)`` and return its result."""
        normalized = [self.norm(x)]
        if y is not None:
            normalized.append(self.norm(y))
        return self.fn(*normalized, **kwargs)


class FFN(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim (int): Input and output feature size.
        hidden_dim (int): Width of the hidden layer.
        dropout (float): Dropout probability used after the activation and
            after the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # The order here fixes the Sequential indices (net.0 / net.3 hold the
        # Linear weights), which are the parameter names in checkpoints.
        layers = [
            nn.Linear(dim, hidden_dim),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to the last axis of ``x``."""
        out = self.net(x)
        return out


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim (int): Feature size of the input's last axis.
        n_heads (int): Number of attention heads.
        dim_single_head (int): Feature size of each head; Q/K/V are projected
            to ``n_heads * dim_single_head`` features.
        dropout (float): Dropout after the output projection (only used when
            an output projection exists).
        out_attention (bool): If True, ``forward`` also returns the softmax
            attention weights.
    """

    def __init__(self,
                 dim,
                 n_heads=8,
                 dim_single_head=64,
                 dropout=0.0,
                 out_attention=False):
        super().__init__()
        inner_dim = n_heads * dim_single_head
        # The output projection is skipped only when a single head already
        # produces exactly ``dim`` features.
        needs_projection = n_heads != 1 or dim_single_head != dim

        self.n_heads = n_heads
        self.scale = dim_single_head**-0.5
        self.out_attention = out_attention

        self.attend = nn.Softmax(dim=-1)
        # One fused projection producing Q, K and V side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend over the token axis of ``x``.

        Args:
            x (Tensor): 3-D input whose last axis has size ``dim``; presumably
                (batch, tokens, dim).

        Returns:
            Tensor with the same leading axes as ``x``, or a tuple of that
            tensor and the attention weights (b, h, n, n) when
            ``out_attention`` is set.
        """
        _, _, _ = x.shape  # rejects inputs that are not 3-D
        heads = self.n_heads
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h=heads)
            for part in self.to_qkv(x).chunk(3, dim=-1))

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.attend(scores)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        projected = self.to_out(merged)

        return (projected, weights) if self.out_attention else projected


class DeformableCrossAttention(nn.Module):
    """Cross-attention from queries into source features via ``MSDeformAttn``.

    The optional positional embedding ``query_pos`` is added to the queries
    before attention and dropout is applied to the attention output. No
    residual connection or normalization is applied inside this module.

    Args:
        dim_model, dim_single_head, n_levels, n_heads, n_points: forwarded
            positionally, in this order, to ``MSDeformAttn`` (presumably model
            width, per-head width, number of feature levels, number of heads
            and sampling points — confirm against ``MSDeformAttn``).
        dropout (float): Dropout probability applied to the attention output.
        out_sample_loc (bool): If True, ``forward`` also returns the sampling
            locations reported by ``MSDeformAttn``.
    """

    def __init__(
        self,
        dim_model=256,
        dim_single_head=64,
        dropout=0.3,
        n_levels=3,
        n_heads=6,
        n_points=9,
        out_sample_loc=False,
    ):
        super().__init__()
        self.out_sample_loc = out_sample_loc

        self.cross_attn = MSDeformAttn(
            dim_model,
            dim_single_head,
            n_levels,
            n_heads,
            n_points,
            out_sample_loc=out_sample_loc)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(
        self,
        tgt,
        src,
        query_pos=None,
        reference_points=None,
        src_spatial_shapes=None,
        level_start_index=None,
        src_padding_mask=None,
    ):
        """Run deformable cross-attention of ``tgt`` into ``src``.

        Returns:
            The dropped-out attention output, or a tuple of it and the
            sampling locations when ``out_sample_loc`` is set.
        """
        query = self.with_pos_embed(tgt, query_pos)
        # MSDeformAttn always yields (output, sampling_locations).
        attended, sampling_locations = self.cross_attn(
            query,
            reference_points,
            src,
            src_spatial_shapes,
            level_start_index,
            src_padding_mask,
        )
        output = self.dropout(attended)

        if not self.out_sample_loc:
            return output
        return output, sampling_locations


class DeformableTransformerDecoder(nn.Module):
    """Deformable transformer decoder.

    Each of the ``depth`` layers applies, with pre-normalization
    (``PreNorm``) and a residual connection around each step:
    self-attention over the queries, deformable cross-attention into
    ``src``, then a feed-forward network.

    Note that the ``DeformableDetrTransformerDecoder`` in MMDet has different
    interfaces in multi-head-attention which is customized here. For example,
    'embed_dims' is not a position argument in our customized multi-head-self-
    attention, but is required in MMDet. Thus, we can not directly use the
    ``DeformableDetrTransformerDecoder`` in MMDET.

    Args:
        dim (int): Feature size of the queries.
        n_levels (int): Number of feature levels; also how many times the
            reference position is repeated per query.
        depth (int): Number of decoder layers.
        n_heads (int): Heads for both the self- and cross-attention.
        dim_single_head (int): Per-head feature size.
        dim_ffn (int): Hidden width of the feed-forward network.
        dropout (float): Dropout probability used throughout.
        out_attention (bool): If True, each layer's cross-attention sampling
            locations are collected and returned.
        n_points (int): Forwarded as ``n_points`` to the deformable
            cross-attention.
    """

    def __init__(
        self,
        dim,
        n_levels=3,
        depth=2,
        n_heads=4,
        dim_single_head=32,
        dim_ffn=256,
        dropout=0.0,
        out_attention=False,
        n_points=9,
    ):
        super().__init__()
        self.out_attention = out_attention
        self.layers = nn.ModuleList([])
        self.depth = depth
        self.n_levels = n_levels
        self.n_points = n_points

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PreNorm(
                        dim,
                        SelfAttention(
                            dim,
                            n_heads=n_heads,
                            dim_single_head=dim_single_head,
                            dropout=dropout,
                            out_attention=self.out_attention,
                        ),
                    ),
                    PreNorm(
                        dim,
                        DeformableCrossAttention(
                            dim,
                            dim_single_head,
                            n_levels=n_levels,
                            n_heads=n_heads,
                            dropout=dropout,
                            n_points=n_points,
                            out_sample_loc=self.out_attention,
                        ),
                    ),
                    PreNorm(dim, FFN(dim, dim_ffn, dropout=dropout)),
                ]))

    def forward(self, x, pos_embedding, src, src_spatial_shapes,
                level_start_index, center_pos):
        """Decode the query features ``x``.

        Args:
            x (Tensor): Query features; presumably (B, N, dim) — the
                self-attention expects a 3-D tensor.
            pos_embedding (callable | None): Maps ``center_pos`` to a
                positional embedding for the queries. ``None`` disables the
                positional embedding.
            src (Tensor): Source features for the cross-attention (the
                cross-attention's ``PreNorm`` layer-normalizes it).
            src_spatial_shapes: Forwarded to the cross-attention.
            level_start_index: Forwarded to the cross-attention.
            center_pos (Tensor): Query reference positions, indexed as
                ``center_pos[:, :, None, :]`` (so 3-D). They are repeated
                ``n_levels`` times to form the reference points.

        Returns:
            dict: ``'ct_feat'`` holds the decoded queries. When
            ``out_attention`` is set, ``'out_attention'`` holds the per-layer
            sampling locations stacked along dim 2.
        """
        # Bound unconditionally: previously it was only assigned when
        # ``pos_embedding`` was given, so ``pos_embedding=None`` raised
        # UnboundLocalError on first use.
        center_pos_embedding = None
        if pos_embedding is not None:
            center_pos_embedding = pos_embedding(center_pos)

        # One copy of each query's reference position per feature level.
        reference_points = center_pos[:, :, None, :].repeat(
            1, 1, self.n_levels, 1)

        out_cross_attention_list = []
        for self_attn, cross_attn, ff in self.layers:
            # The positional embedding only feeds the self-attention input;
            # the residual is added to the un-embedded ``x``.
            if center_pos_embedding is None:
                attn_in = x
            else:
                attn_in = x + center_pos_embedding
            if self.out_attention:
                # The self-attention weights are returned but not used.
                x_att, _ = self_attn(attn_in)
            else:
                x_att = self_attn(attn_in)
            x = x_att + x

            cross_out = cross_attn(
                x,
                src,
                query_pos=center_pos_embedding,
                reference_points=reference_points,
                src_spatial_shapes=src_spatial_shapes,
                level_start_index=level_start_index,
            )
            if self.out_attention:
                x_att, sampling_locations = cross_out
                out_cross_attention_list.append(sampling_locations)
            else:
                x_att = cross_out
            x = x_att + x

            x = ff(x) + x

        out_dict = {'ct_feat': x}
        if self.out_attention:
            out_dict['out_attention'] = torch.stack(
                out_cross_attention_list, dim=2)
        return out_dict