models.py 11.7 KB
Newer Older
Zihao Ye's avatar
Zihao Ye committed
1
2
3
4
5
6
7
8
from .config import *
from .act import *
from .attention import *
from .viz import *
from .layers import *
from .functions import *
from .embedding import *
import threading
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
9

Zihao Ye's avatar
Zihao Ye committed
10
import dgl.function as fn
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
11
import torch as th
Zihao Ye's avatar
Zihao Ye committed
12
13
import torch.nn.init as INIT

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
14

Zihao Ye's avatar
Zihao Ye committed
15
16
17
18
19
20
21
class Encoder(nn.Module):
    """Stack of ``N`` encoder layers evaluated through DGL node UDFs.

    Instead of a ``forward`` method, the encoder hands out factories that
    build per-layer node functions for ``Transformer.update_graph``:
    ``pre_func`` runs before message passing and ``post_func`` after it.
    """

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields="qkv"):
        """Return a node UDF preparing attention inputs for layer ``i``.

        The UDF normalizes ``nodes.data["x"]`` with the norm of the layer's
        first sublayer and returns ``self_attn.get(..., fields=fields)``.
        """
        enc_layer = self.layers[i]

        def func(nodes):
            normed = enc_layer.sublayer[0].norm(nodes.data["x"])
            return enc_layer.self_attn.get(normed, fields=fields)

        return func

    def post_func(self, i):
        """Return a node UDF finishing layer ``i`` after message passing.

        The UDF divides the aggregated ``wv`` by ``z``, applies the attention
        output projection, adds the residual, runs the feed-forward
        sublayer, and applies the final norm when ``i`` is the last layer.
        """
        enc_layer = self.layers[i]

        def func(nodes):
            data = nodes.data
            attn_out = enc_layer.self_attn.get_o(data["wv"] / data["z"])
            hidden = data["x"] + enc_layer.sublayer[0].dropout(attn_out)
            hidden = enc_layer.sublayer[1](hidden, enc_layer.feed_forward)
            if i >= self.N - 1:
                hidden = self.norm(hidden)
            return {"x": hidden}

        return func

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
44

Zihao Ye's avatar
Zihao Ye committed
45
46
47
48
49
50
51
class Decoder(nn.Module):
    """Stack of ``N`` decoder layers evaluated through DGL node UDFs.

    Each decoder layer has two attention sublayers, selected by ``l``:
    ``l == 0`` is decoder self-attention (``layer.self_attn``) and ``l == 1``
    is source attention over encoder nodes (``layer.src_attn``), followed by
    the feed-forward sublayer ``sublayer[2]``.
    """

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields="qkv", l=0):
        """Return a node UDF preparing attention inputs for layer ``i``.

        Args:
            i: layer index.
            fields: which projections to compute. ``"qkv"`` uses
                ``self_attn``; any other value (``"q"``/``"kv"``) uses
                ``src_attn``.
            l: index of the sublayer whose norm is applied to the query side.
        """
        layer = self.layers[i]

        def func(nodes):
            x = nodes.data["x"]
            # Only the query side is normalized here; "kv" is computed on
            # encoder nodes, whose last-layer output the encoder already
            # normalizes.
            norm_x = layer.sublayer[l].norm(x) if fields.startswith("q") else x
            if fields != "qkv":
                return layer.src_attn.get(norm_x, fields)
            return layer.self_attn.get(norm_x, fields)

        return func

    def post_func(self, i, l=0):
        """Return a node UDF finishing sublayer ``l`` of layer ``i``.

        The UDF divides the aggregated ``wv`` by ``z``, applies the output
        projection of the same attention module used in ``pre_func``, and
        adds the residual. For ``l == 1`` it also runs the feed-forward
        sublayer. The final ``self.norm`` is applied exactly once, after the
        feed-forward of the last layer.
        """
        layer = self.layers[i]
        # Must match the module that produced q/k/v in pre_func; previously
        # self_attn's output projection was (wrongly) reused for l == 1.
        attn = layer.src_attn if l == 1 else layer.self_attn

        def func(nodes):
            x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
            o = attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
            if l == 1:
                x = layer.sublayer[2](x, layer.feed_forward)
                # The final norm belongs after the whole last layer, not
                # between its self-attention and source-attention sublayers.
                if i >= self.N - 1:
                    x = self.norm(x)
            return {"x": x}

        return func

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
78

Zihao Ye's avatar
Zihao Ye committed
79
class Transformer(nn.Module):
    """Encoder-decoder Transformer computed by message passing on a DGL graph.

    Tokens are nodes of ``graph.g``; ``graph.nids`` holds the ``"enc"`` and
    ``"dec"`` node ids and ``graph.eids`` the ``"ee"`` (encoder
    self-attention), ``"ed"`` (encoder -> decoder attention) and ``"dd"``
    (decoder self-attention) edge ids.
    """

    def __init__(
        self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k
    ):
        super(Transformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc = pos_enc
        self.generator = generator
        self.h, self.d_k = h, d_k
        self.att_weight_map = None

    def propagate_attention(self, g, eids):
        """Run one attention round over the edges ``eids``.

        Edge scores come from ``src_dot_dst`` and ``scaled_exp``, which are
        defined elsewhere. The score-weighted sum of ``v`` is stored in
        ``ndata["wv"]`` and the sum of scores in ``ndata["z"]`` of the
        destination nodes. The ``post_func`` UDFs compute ``wv / z``.
        """
        # Compute attention score
        g.apply_edges(src_dot_dst("k", "q", "score"), eids)
        g.apply_edges(scaled_exp("score", np.sqrt(self.d_k)), eids)
        # Send weighted values to target nodes
        g.send_and_recv(eids, fn.u_mul_e("v", "score", "v"), fn.sum("v", "wv"))
        g.send_and_recv(eids, fn.copy_e("score", "score"), fn.sum("score", "z"))

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."

        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        """Teacher-forced pass; returns generator output for decoder nodes."""
        g = graph.g
        nids, eids = graph.nids, graph.eids

        # embed
        src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(
            graph.src[1]
        )
        tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(
            graph.tgt[1]
        )
        g.nodes[nids["enc"]].data["x"] = self.pos_enc.dropout(
            src_embed + src_pos
        )
        g.nodes[nids["dec"]].data["x"] = self.pos_enc.dropout(
            tgt_embed + tgt_pos
        )

        for i in range(self.encoder.N):
            pre_func = self.encoder.pre_func(i, "qkv")
            post_func = self.encoder.post_func(i)
            nodes, edges = nids["enc"], eids["ee"]
            self.update_graph(
                g, edges, [(pre_func, nodes)], [(post_func, nodes)]
            )

        for i in range(self.decoder.N):
            pre_func = self.decoder.pre_func(i, "qkv")
            post_func = self.decoder.post_func(i)
            nodes, edges = nids["dec"], eids["dd"]
            self.update_graph(
                g, edges, [(pre_func, nodes)], [(post_func, nodes)]
            )
            pre_q = self.decoder.pre_func(i, "q", 1)
            pre_kv = self.decoder.pre_func(i, "kv", 1)
            post_func = self.decoder.post_func(i, 1)
            nodes_e, edges = nids["enc"], eids["ed"]
            self.update_graph(
                g,
                edges,
                [(pre_q, nodes), (pre_kv, nodes_e)],
                [(post_func, nodes)],
            )

        # Attention visualization (currently disabled):
        # if self.att_weight_map is None:
        #     self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX],
        #                            graph.nid_arr['dec'][VIZ_IDX])

        return self.generator(g.ndata["x"][nids["dec"]])

    def infer(self, graph, max_len, eos_id, k, alpha=1.0):
        """
        Beam search decoding in DGL, used in the inference phase.

        Length normalization is given by (5 + len) ^ alpha / 6 ^ alpha. Please
        refer to https://arxiv.org/pdf/1609.08144.pdf.

        args:
            graph: a `Graph` object defined in `dgl.contrib.transformer.graph`.
            max_len: the maximum length of decoding; must be at least 2
                because decoding starts at position 1.
            eos_id: the index of end-of-sequence symbol.
            k: beam size; must be positive.
            alpha: length-normalization exponent.
        return:
            ret: a list of index lists (the best beam), one per input
                sequence in `graph`.
        raises:
            ValueError: if ``max_len < 2`` or ``k < 1``.
        """
        if max_len < 2:
            raise ValueError("max_len must be at least 2, got {}".format(max_len))
        if k < 1:
            raise ValueError("beam size k must be positive, got {}".format(k))

        g = graph.g
        N = graph.n_nodes
        nids, eids = graph.nids, graph.eids

        # embed & pos
        src_embed = self.src_embed(graph.src[0])
        src_pos = self.pos_enc(graph.src[1])
        g.nodes[nids["enc"]].data["pos"] = graph.src[1]
        g.nodes[nids["enc"]].data["x"] = self.pos_enc.dropout(
            src_embed + src_pos
        )
        tgt_pos = self.pos_enc(graph.tgt[1])
        g.nodes[nids["dec"]].data["pos"] = graph.tgt[1]

        # init mask: nonzero marks decoder nodes whose beam already emitted EOS
        device = next(self.parameters()).device
        g.ndata["mask"] = th.zeros(N, dtype=th.uint8, device=device)

        # encode
        for i in range(self.encoder.N):
            pre_func = self.encoder.pre_func(i, "qkv")
            post_func = self.encoder.post_func(i)
            nodes, edges = nids["enc"], eids["ee"]
            self.update_graph(
                g, edges, [(pre_func, nodes)], [(post_func, nodes)]
            )

        # decode
        log_prob = None
        y = graph.tgt[0]
        for step in range(1, max_len):
            y = y.view(-1)
            tgt_embed = self.tgt_embed(y)
            g.ndata["x"][nids["dec"]] = self.pos_enc.dropout(
                tgt_embed + tgt_pos
            )
            # Only positions already decoded and beams not yet finished
            # take part in this step.
            edges_ed = g.filter_edges(
                lambda e: (e.dst["pos"] < step) & ~e.dst["mask"].bool(),
                eids["ed"],
            )
            edges_dd = g.filter_edges(
                lambda e: (e.dst["pos"] < step) & ~e.dst["mask"].bool(),
                eids["dd"],
            )
            nodes_d = g.filter_nodes(
                lambda v: (v.data["pos"] < step) & ~v.data["mask"].bool(),
                nids["dec"],
            )
            for i in range(self.decoder.N):
                pre_func = self.decoder.pre_func(i, "qkv")
                post_func = self.decoder.post_func(i)
                self.update_graph(
                    g, edges_dd, [(pre_func, nodes_d)], [(post_func, nodes_d)]
                )
                pre_q = self.decoder.pre_func(i, "q", 1)
                pre_kv = self.decoder.pre_func(i, "kv", 1)
                post_func = self.decoder.post_func(i, 1)
                self.update_graph(
                    g,
                    edges_ed,
                    [(pre_q, nodes_d), (pre_kv, nids["enc"])],
                    [(post_func, nodes_d)],
                )

            frontiers = g.filter_nodes(
                lambda v: v.data["pos"] == step - 1, nids["dec"]
            )
            out = self.generator(g.ndata["x"][frontiers])
            batch_size = frontiers.shape[0] // k
            vocab_size = out.shape[-1]

            # Mask output for complete sequence: a finished beam may only be
            # extended by EOS at zero cost. th.where (instead of
            # out * (1 - mask)) avoids NaN when out contains -inf.
            eos_only = th.full(
                (vocab_size,), -1e9, dtype=out.dtype, device=out.device
            )
            eos_only[eos_id] = 0
            finished = g.ndata["mask"][frontiers].bool().unsqueeze(-1)
            out = th.where(finished, eos_only, out)

            if log_prob is None:
                # First step: all k beams are identical, so expand beam 0 only.
                log_prob, pos = out.view(batch_size, k, -1)[:, 0, :].topk(
                    k, dim=-1
                )
                eos = th.zeros(batch_size, k).byte()
            else:
                eos_f = eos.float().to(device)
                norm_old = eos_f + (1 - eos_f) * np.power((4.0 + step) / 6, alpha)
                norm_new = eos_f + (1 - eos_f) * np.power((5.0 + step) / 6, alpha)
                scores = (
                    out.view(batch_size, k, -1)
                    + (log_prob * norm_old).unsqueeze(-1)
                ) / norm_new.unsqueeze(-1)
                log_prob, pos = scores.view(batch_size, -1).topk(k, dim=-1)

            # Reorder beams: pos indexes the flattened (beam, token) space.
            _y = y.view(batch_size * k, -1)
            y = th.zeros_like(_y)
            _eos = eos.clone()
            # One transfer instead of a device sync per .item() call.
            for i, row in enumerate(pos.tolist()):
                for j, flat_idx in enumerate(row):
                    _j, token = divmod(flat_idx, vocab_size)
                    y[i * k + j, :] = _y[i * k + _j, :]
                    y[i * k + j, step] = token
                    eos[i, j] = _eos[i, _j] | (token == eos_id)

            if eos.all():
                break
            g.ndata["mask"][nids["dec"]] = (
                eos.unsqueeze(-1).repeat(1, 1, max_len).view(-1).to(device)
            )
        return y.view(batch_size, k, -1)[:, 0, :].tolist()

    def _register_att_map(self, g, enc_ids, dec_ids):
        """Store attention-map handles for enc-enc, enc-dec and dec-dec."""
        self.att_weight_map = [
            get_attention_map(g, enc_ids, enc_ids, self.h),
            get_attention_map(g, enc_ids, dec_ids, self.h),
            get_attention_map(g, dec_ids, dec_ids, self.h),
        ]


Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
305
306
307
308
309
310
311
312
313
314
def make_model(
    src_vocab,
    tgt_vocab,
    N=6,
    dim_model=512,
    dim_ff=2048,
    h=8,
    dropout=0.1,
    universal=False,
):
    """Build a graph-based Transformer with Xavier-initialized weight matrices.

    Args:
        src_vocab: size of the source vocabulary.
        tgt_vocab: size of the target vocabulary.
        N: number of encoder layers and of decoder layers.
        dim_model: model (embedding) width.
        dim_ff: hidden width of the position-wise feed-forward network.
        h: number of attention heads; each head gets ``dim_model // h`` dims.
        dropout: dropout rate passed to the layers and positional encoding.
        universal: when true, delegate to ``make_universal_model`` (``N`` is
            not forwarded in that case).

    Returns:
        A ``Transformer`` instance, or the result of ``make_universal_model``.
    """
    if universal:
        return make_universal_model(
            src_vocab, tgt_vocab, dim_model, dim_ff, h, dropout
        )

    # Prototype sub-modules; every layer receives its own deep copy.
    # Construction order is kept as-is since module constructors presumably
    # draw from the global RNG (matters for seeded reproducibility).
    dup = copy.deepcopy
    attn_proto = MultiHeadAttention(h, dim_model)
    ff_proto = PositionwiseFeedForward(dim_model, dim_ff)
    pos_enc = PositionalEncoding(dim_model, dropout)

    enc_layer = EncoderLayer(dim_model, dup(attn_proto), dup(ff_proto), dropout)
    encoder = Encoder(enc_layer, N)
    dec_layer = DecoderLayer(
        dim_model, dup(attn_proto), dup(attn_proto), dup(ff_proto), dropout
    )
    decoder = Decoder(dec_layer, N)

    model = Transformer(
        encoder=encoder,
        decoder=decoder,
        src_embed=Embeddings(src_vocab, dim_model),
        tgt_embed=Embeddings(tgt_vocab, dim_model),
        pos_enc=pos_enc,
        generator=Generator(dim_model, tgt_vocab),
        h=h,
        d_k=dim_model // h,
    )

    # xavier init: only parameters with more than one dimension are touched.
    for weight in (p for p in model.parameters() if p.dim() > 1):
        INIT.xavier_uniform_(weight)
    return model