embedding.py 3.36 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class VocabEmbedding(torch.nn.Module):

    def __init__(self, num_embeddings, embedding_dim):
        super(VocabEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None

        # Allocate weights and initialize.
        self.weight = nn.Parameter(torch.empty(
            self.num_embeddings, self.embedding_dim))
        init.xavier_uniform_(self.weight)

    def forward(self, hidden_state):
        output = F.embedding(hidden_state, self.weight,
                             self.padding_idx, self.max_norm,
                             self.norm_type, self.scale_grad_by_freq,
                             self.sparse)
        return output

    def __repr__(self):
        return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \
            f'embedding_dim={self.embedding_dim})'


class Embedding(nn.Module):
    """Language model embeddings.
    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 num_tokentypes):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.num_tokentypes = num_tokentypes

        self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size)

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    @property
    def word_embedding_weight(self):
        return self.word_embeddings.weight

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None and self.tokentype_embeddings is not None:
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings