graph.py 8.53 KB
Newer Older
Zihao Ye's avatar
Zihao Ye committed
1
2
3
4
import itertools
import time
from collections import *

5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
import torch as th

import dgl

# Record returned by GraphPool.beam() and GraphPool.__call__():
#   g        -- batched DGL graph (dgl.batch of pooled graphs)
#   src      -- (source token ids, source positions), concatenated over samples
#   tgt      -- (decoder input ids, decoder positions), concatenated
#   tgt_y    -- decoder labels (None for beam search)
#   nids     -- {"enc", "dec"}: concatenated global node ids per node type
#   eids     -- {"ee", "ed", "dd"}: concatenated global edge ids per edge type
#   nid_arr  -- {"enc", "dec"}: per-sample lists of node-id tensors
#   n_nodes, n_edges, n_tokens -- batch totals
Graph = namedtuple(
    "Graph",
    "g src tgt tgt_y nids eids nid_arr n_nodes n_edges n_tokens",
)

Zihao Ye's avatar
Zihao Ye committed
26
27

class GraphPool:
    """Create a graph pool in advance to accelerate graph building phase in Transformer.

    One DGL graph is built for every (source length, target length) pair up
    to the configured maxima. Each pooled graph holds the encoder nodes,
    followed by the decoder nodes, with three edge groups added in this
    order: enc->enc, enc->dec, dec->dec. ``beam`` and ``__call__`` only look
    graphs up and batch them.
    """

    def __init__(self, n=50, m=50):
        """Build all graphs for source lengths 1..n and target lengths 1..m.

        args:
            n: maximum length of input sequence.
            m: maximum length of output sequence.
        """
        print("start creating graph pool...")
        tic = time.time()
        self.n, self.m = n, m
        g_pool = [[dgl.graph(([], [])) for _ in range(m)] for _ in range(n)]
        # Every table is indexed [i][j] with i < n (source) and j < m (target),
        # so each must have shape (n, m). Sizing "ee" as (n, n) and "dd" as
        # (m, m) raised IndexError whenever n != m.
        num_edges = {
            "ee": np.zeros((n, m), dtype=int),
            "ed": np.zeros((n, m), dtype=int),
            "dd": np.zeros((n, m), dtype=int),
        }
        for i, j in itertools.product(range(n), range(m)):
            src_length = i + 1
            tgt_length = j + 1

            # Nodes 0..src_length-1 are encoder positions; the following
            # tgt_length nodes are decoder positions.
            g_pool[i][j].add_nodes(src_length + tgt_length)
            enc_nodes = th.arange(src_length, dtype=th.long)
            dec_nodes = th.arange(tgt_length, dtype=th.long) + src_length

            # enc -> enc: every encoder node to every encoder node.
            us = enc_nodes.unsqueeze(-1).repeat(1, src_length).view(-1)
            vs = enc_nodes.repeat(src_length)
            g_pool[i][j].add_edges(us, vs)
            num_edges["ee"][i][j] = len(us)

            # enc -> dec: every encoder node to every decoder node.
            us = enc_nodes.unsqueeze(-1).repeat(1, tgt_length).view(-1)
            vs = dec_nodes.repeat(src_length)
            g_pool[i][j].add_edges(us, vs)
            num_edges["ed"][i][j] = len(us)

            # dec -> dec: edge u -> v only when u <= v (upper triangle), so a
            # decoder position receives messages from itself and earlier ones.
            indices = th.triu(th.ones(tgt_length, tgt_length)) == 1
            us = dec_nodes.unsqueeze(-1).repeat(1, tgt_length)[indices]
            vs = dec_nodes.unsqueeze(0).repeat(tgt_length, 1)[indices]
            g_pool[i][j].add_edges(us, vs)
            num_edges["dd"][i][j] = len(us)

        print(
            "successfully created graph pool, time: {0:0.3f}s".format(
                time.time() - tic
            )
        )
        self.g_pool = g_pool
        self.num_edges = num_edges

    def _lookup(self, src_len, tgt_len):
        """Return (graph, n_ee, n_ed, n_dd) for the given sequence lengths.

        raises:
            ValueError: if either length is < 1. Without this check, index -1
                would silently select the largest pooled graph.
            IndexError: if a length exceeds the pool size given to __init__.
        """
        if src_len < 1 or tgt_len < 1:
            raise ValueError(
                "sequence lengths must be >= 1, got src_len={}, tgt_len={}".format(
                    src_len, tgt_len
                )
            )
        if src_len > self.n or tgt_len > self.m:
            raise IndexError(
                "lengths (src_len={}, tgt_len={}) exceed graph pool size "
                "(n={}, m={})".format(src_len, tgt_len, self.n, self.m)
            )
        i, j = src_len - 1, tgt_len - 1
        return (
            self.g_pool[i][j],
            int(self.num_edges["ee"][i][j]),
            int(self.num_edges["ed"][i][j]),
            int(self.num_edges["dd"][i][j]),
        )

    def _batch(self, specs, device):
        """Batch pooled graphs and compute their global node / edge id ranges.

        args:
            specs: list of (src_len, tgt_len) pairs, one per graph in the batch.
            device: 'cpu' or 'cuda:*'
        returns:
            (g, enc_ids, dec_ids, eids, n_nodes, n_edges). enc_ids and dec_ids
            are per-graph lists of node-id tensors. eids maps "ee"/"ed"/"dd" to
            the concatenated edge ids of that group.
        """

        def _span(start, length):
            return th.arange(start, start + length, dtype=th.long, device=device)

        g_list = []
        enc_ids, dec_ids = [], []
        edge_ids = {"ee": [], "ed": [], "dd": []}
        n_nodes, n_edges = 0, 0
        for src_len, tgt_len in specs:
            graph, n_ee, n_ed, n_dd = self._lookup(src_len, tgt_len)
            g_list.append(graph)
            # Node ids follow each pooled graph's layout: encoder, then decoder.
            enc_ids.append(_span(n_nodes, src_len))
            n_nodes += src_len
            dec_ids.append(_span(n_nodes, tgt_len))
            n_nodes += tgt_len
            # Edge ids follow the order the groups were added in __init__.
            for key, count in (("ee", n_ee), ("ed", n_ed), ("dd", n_dd)):
                edge_ids[key].append(_span(n_edges, count))
                n_edges += count

        g = dgl.batch(g_list)
        g.set_n_initializer(dgl.init.zero_initializer)
        g.set_e_initializer(dgl.init.zero_initializer)
        g = g.to(device).long()
        eids = {key: th.cat(ids) for key, ids in edge_ids.items()}
        return g, enc_ids, dec_ids, eids, n_nodes, n_edges

    def beam(self, src_buf, start_sym, max_len, k, device="cpu"):
        """
        Return a batched graph for beam search during inference of Transformer.
        Each source sequence is replicated k times.
        args:
            src_buf: a list of input sequence
            start_sym: the index of start-of-sequence symbol
            max_len: maximum length for decoding
            k: beam size
            device: 'cpu' or 'cuda:*'
        returns:
            Graph with tgt_y=None and n_tokens=0 (there are no labels here).
        """
        src_lens = [len(sample) for sample in src_buf]
        specs = [(src_len, max_len) for src_len in src_lens for _ in range(k)]
        g, enc_ids, dec_ids, eids, n_nodes, n_edges = self._batch(specs, device)

        src, src_pos, tgt, tgt_pos = [], [], [], []
        for src_sample, src_len in zip(src_buf, src_lens):
            for _ in range(k):
                src.append(th.tensor(src_sample, dtype=th.long, device=device))
                src_pos.append(th.arange(src_len, dtype=th.long, device=device))
                # Decoder input: start symbol followed by zero padding.
                tgt_seq = th.zeros(max_len, dtype=th.long, device=device)
                tgt_seq[0] = start_sym
                tgt.append(tgt_seq)
                tgt_pos.append(th.arange(max_len, dtype=th.long, device=device))

        return Graph(
            g=g,
            src=(th.cat(src), th.cat(src_pos)),
            tgt=(th.cat(tgt), th.cat(tgt_pos)),
            tgt_y=None,
            nids={"enc": th.cat(enc_ids), "dec": th.cat(dec_ids)},
            eids=eids,
            nid_arr={"enc": enc_ids, "dec": dec_ids},
            n_nodes=n_nodes,
            n_edges=n_edges,
            n_tokens=0,
        )

    def __call__(self, src_buf, tgt_buf, device="cpu"):
        """
        Return a batched graph for the training phase of Transformer.
        args:
            src_buf: a set of input sequence arrays.
            tgt_buf: a set of output sequence arrays.
            device: 'cpu' or 'cuda:*'
        returns:
            Graph whose decoder input is tgt[:-1] and labels are tgt[1:] for
            each target. n_tokens is the total number of decoder positions.
        """
        src_lens = [len(sample) for sample in src_buf]
        # Shifting input/label by one leaves len - 1 decoder positions.
        tgt_lens = [len(sample) - 1 for sample in tgt_buf]
        specs = list(zip(src_lens, tgt_lens))
        g, enc_ids, dec_ids, eids, n_nodes, n_edges = self._batch(specs, device)

        src, src_pos, tgt, tgt_pos, tgt_y = [], [], [], [], []
        for src_sample, tgt_sample, (src_len, tgt_len) in zip(
            src_buf, tgt_buf, specs
        ):
            src.append(th.tensor(src_sample, dtype=th.long, device=device))
            tgt.append(th.tensor(tgt_sample[:-1], dtype=th.long, device=device))
            tgt_y.append(th.tensor(tgt_sample[1:], dtype=th.long, device=device))
            src_pos.append(th.arange(src_len, dtype=th.long, device=device))
            tgt_pos.append(th.arange(tgt_len, dtype=th.long, device=device))

        return Graph(
            g=g,
            src=(th.cat(src), th.cat(src_pos)),
            tgt=(th.cat(tgt), th.cat(tgt_pos)),
            tgt_y=th.cat(tgt_y),
            nids={"enc": th.cat(enc_ids), "dec": th.cat(dec_ids)},
            eids=eids,
            nid_arr={"enc": enc_ids, "dec": dec_ids},
            n_nodes=n_nodes,
            n_edges=n_edges,
            n_tokens=sum(tgt_len for _, tgt_len in specs),
        )