import numpy as np
import torch as th
import torch.nn as nn


class PositionalEncoding(nn.Module):
    "Position Encoding module"

    def __init__(self, dim_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Compute the positional encodings once in log space.
        pe = th.zeros(max_len, dim_model, dtype=th.float)
        position = th.arange(0, max_len, dtype=th.float).unsqueeze(1)
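        # The divisor below equals 1 / 10000^(2i / dim_model) for each even
        # index 2i, giving the standard sinusoidal encoding
        #   pe[pos, 2i]     = sin(pos / 10000^(2i / dim_model))
        #   pe[pos, 2i + 1] = cos(pos / 10000^(2i / dim_model))
        # from "Attention Is All You Need".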
        div_term = th.exp(
            th.arange(0, dim_model, 2, dtype=th.float)
            * -(np.log(10000.0) / dim_model)
        )
        pe[:, 0::2] = th.sin(position * div_term)
        pe[:, 1::2] = th.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer(
            "pe", pe
        )  # Not a parameter but should be in state_dict

    def forward(self, pos):
        # pos: 1-D LongTensor of position indices; returns the corresponding
        # rows of the precomputed table, shape (len(pos), dim_model).
        return th.index_select(self.pe, 1, pos).squeeze(0)


class Embeddings(nn.Module):
    "Word Embedding module"

    def __init__(self, vocab_size, dim_model):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab_size, dim_model)
        self.dim_model = dim_model

    def forward(self, x):
        # Scale the embeddings by sqrt(dim_model), as in the original
        # Transformer ("Attention Is All You Need").
        return self.lut(x) * np.sqrt(self.dim_model)
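

# A minimal usage sketch (added for illustration; the vocabulary size, model
# width, and dropout below are arbitrary assumptions, not values taken from
# this module).
if __name__ == "__main__":
    vocab_size, dim_model = 1000, 512
    emb = Embeddings(vocab_size, dim_model)
    pos_enc = PositionalEncoding(dim_model, dropout=0.1)

    tokens = th.randint(0, vocab_size, (10,))  # a sequence of 10 token ids
    positions = th.arange(0, 10)               # their positions 0..9

    # Scaled token embeddings plus positional encodings, as combined at the
    # input of a standard Transformer.
    x = emb(tokens) + pos_enc(positions)
    print(x.shape)  # torch.Size([10, 512])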