#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed into multi-head attention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class Transformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # src shape (B, C, H, W)
        # mask shape (B, H, W)
        # query_embed shape (M, C)
        # pos_embed shape (B, C, H, W)

        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)  # shape (L, B, C)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)  # shape (L, B, C)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # shape (M, B, C)
        mask = mask.flatten(1)  # shape (B, HxW)

        tgt = torch.zeros_like(query_embed)
        # memory shape (L, B, C)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # hs shape (NUM_LAYERS, M, B, C)
        hs = self.decoder(
            tgt,
            memory,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed,
        )
        # return shapes (NUM_LAYERS, B, M, C) and (B, C, H, W)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src
        # mask, shape (L, L)
        # src_key_padding_mask, shape (B, L)
        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []
        # tgt shape (L, B, C)
        # tgt_mask shape (L, L)
        # tgt_key_padding_mask shape (B, L)
        # memory_mask shape (L, S)
        # memory_key_padding_mask shape (B, S)
        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)
        # return shape (1, L, B, C) when intermediate outputs are not returned
        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)  # shape (L, B, D)
        # src_mask shape (L, L)
        # src_key_padding_mask shape (B, L)
        src2 = self.self_attn(
            q, k, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        # tgt shape (L, B, C)
        # tgt_mask shape (L, L)
        # tgt_key_padding_mask shape (B, L)
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # memory_mask shape (L, S)
        # memory_key_padding_mask shape (B, S)
        # query_pos shape (L, B, C)
        tgt2 = self.multihead_attn(
            self.with_pos_embed(tgt, query_pos),
            self.with_pos_embed(memory, pos),
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        # return tgt shape (L, B, C)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            self.with_pos_embed(tgt2, query_pos),
            self.with_pos_embed(memory, pos),
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
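

# Minimal usage sketch (illustrative only, not part of the original module):
# builds a small Transformer and runs dummy inputs through it to show the
# expected tensor shapes. All sizes below (batch size, feature-map size,
# number of queries) are arbitrary values chosen for the demo.
if __name__ == "__main__":
    model = Transformer(
        d_model=256,
        nhead=8,
        num_encoder_layers=2,
        num_decoder_layers=2,
        return_intermediate_dec=True,
    )
    bs, c, h, w = 2, 256, 16, 16  # assumed backbone feature-map shape
    num_queries = 100  # assumed number of object queries
    src = torch.randn(bs, c, h, w)
    mask = torch.zeros(bs, h, w, dtype=torch.bool)  # False = position not padded
    query_embed = torch.randn(num_queries, c)
    pos_embed = torch.randn(bs, c, h, w)
    hs, memory = model(src, mask, query_embed, pos_embed)
    # hs: (num_decoder_layers, bs, num_queries, c); memory: (bs, c, h, w)
    print(hs.shape, memory.shape)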