import copy

import torch as th
import torch.nn as nn
from torch.nn import LayerNorm

Zihao Ye's avatar
Zihao Ye committed
6
class Generator(nn.Module):
    """Project hidden states onto the vocabulary as log-probabilities.

    Kept separate from the decoder mainly so the projection weight can be
    shared with the token embedding. Computes log(softmax(W x + b)).
    """

    def __init__(self, dim_model, vocab_size):
        super(Generator, self).__init__()
        # Linear map from model dimension to vocabulary logits.
        self.proj = nn.Linear(dim_model, vocab_size)

    def forward(self, x):
        # Produce logits, then normalize in log-space along the vocab axis.
        logits = self.proj(x)
        return th.log_softmax(logits, dim=-1)
class SubLayerWrapper(nn.Module):
    """Pre-norm residual wrapper around an arbitrary sub-layer.

    Bundles layer normalization, dropout, and the residual connection
    into a single call:
        sublayerwrapper(sublayer)(x) = x + dropout(sublayer(norm(x)))
    """

    def __init__(self, size, dropout):
        super(SubLayerWrapper, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Normalize first (pre-norm), apply the sub-layer, drop out,
        # then add the residual input back.
        normed = self.norm(x)
        return x + self.dropout(sublayer(normed))
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward network.

    Applied after the multi-head attention sub-layer:
        FFN(x) = max(0, x @ W_1 + b_1) @ W_2 + b_2
    with dropout applied to the hidden (ReLU) activation.
    """

    def __init__(self, dim_model, dim_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(dim_model, dim_ff)
        self.w_2 = nn.Linear(dim_ff, dim_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Expand to dim_ff, rectify, drop out, then project back to dim_model.
        hidden = th.relu(self.w_1(x))
        return self.w_2(self.dropout(hidden))
def clones(module, k):
    """Return an nn.ModuleList of ``k`` independent deep copies of ``module``.

    Each copy has its own parameters (no weight sharing), so the clones can
    be trained independently — used below to replicate SubLayerWrapper.
    (The ``import copy`` that previously sat here mid-file has been moved to
    the top-of-file import block per PEP 8.)
    """
    return nn.ModuleList(copy.deepcopy(module) for _ in range(k))
class EncoderLayer(nn.Module):
    """A single encoder block: self-attention followed by a feed-forward
    network, each wrapped in its own SubLayerWrapper (norm/dropout/residual).
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.size = size
        # Self-attention module; invoked as (key, query, value, mask).
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # Two wrappers: one for the attention sub-layer, one for the FFN.
        self.sublayer = clones(SubLayerWrapper(size, dropout), 2)

class DecoderLayer(nn.Module):
    """A single decoder block: masked self-attention, encoder-decoder
    (source) attention, then a feed-forward network — each in its own
    SubLayerWrapper (norm/dropout/residual).
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        # Both attention modules are invoked as (key, query, value, mask).
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # Three wrappers: self-attention, source attention, and the FFN.
        self.sublayer = clones(SubLayerWrapper(size, dropout), 3)