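"""Universal Transformer (https://arxiv.org/pdf/1807.03819.pdf) with Adaptive
Computation Time (ACT, https://arxiv.org/pdf/1603.08983.pdf), implemented as
message passing over DGL graphs."""
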
import copy

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.init as INIT

import dgl.function as fn

from .attention import *
from .embedding import *
from .functions import *
from .layers import *


class UEncoder(nn.Module):
    def __init__(self, layer):
        super(UEncoder, self).__init__()
        self.layer = layer
        self.norm = LayerNorm(layer.size)

    def pre_func(self, fields="qkv"):
        layer = self.layer

        def func(nodes):
            # Normalize the node state and emit the requested projections
            # (any of q/k/v) into the node feature dict.
            x = nodes.data["x"]
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)

        return func

    def post_func(self):
        layer = self.layer

        def func(nodes):
            # wv / z completes the softmax normalization of the aggregated
            # values; apply the attention output transform, then residual +
            # dropout and the position-wise feed-forward sublayer.
            x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {"x": x}

        return func
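
# UEncoder (above) and UDecoder (below) do not define forward() themselves;
# UTransformer drives them as DGL node UDFs wrapped around edge-wise attention.
# A sketch of one encoder step, mirroring what UTransformer.forward does:
#
#     pre_func = self.encoder.pre_func("qkv")   # write q/k/v into node data
#     post_func = self.encoder.post_func()      # residual + FFN after attention
#     self.update_graph(
#         g, edges,
#         [(self.step_forward, nodes), (pre_func, nodes)],
#         [(post_func, nodes), (self.halt_and_accum("enc", end), nodes)],
#     )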


class UDecoder(nn.Module):
    def __init__(self, layer):
        super(UDecoder, self).__init__()
        self.layer = layer
        self.norm = LayerNorm(layer.size)

    def pre_func(self, fields="qkv", l=0):
        layer = self.layer

        def func(nodes):
            x = nodes.data["x"]
            if fields == "kv":
                # Keys/values come from encoder nodes, whose states have
                # already been normalized by the encoder's final norm.
                norm_x = x
            else:
                norm_x = layer.sublayer[l].norm(x)
            return layer.self_attn.get(norm_x, fields)

        return func

    def post_func(self, l=0):
        layer = self.layer

        def func(nodes):
            x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
            if l == 1:
                # The cross-attention sublayer (l == 1) is followed by the
                # feed-forward sublayer.
                x = layer.sublayer[2](x, layer.feed_forward)
            return {"x": x}

        return func


class HaltingUnit(nn.Module):
    halting_bias_init = 1.0
    def __init__(self, dim_model):
        super(HaltingUnit, self).__init__()
        self.linear = nn.Linear(dim_model, 1)
        self.norm = LayerNorm(dim_model)
        INIT.constant_(self.linear.bias, self.halting_bias_init)

    def forward(self, x):
        return th.sigmoid(self.linear(self.norm(x)))
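
# A minimal usage sketch (dim_model=512 and the batch size are placeholders;
# LayerNorm is provided by the .layers import): the unit maps each row of node
# states to a halting probability in (0, 1), and the bias init of 1.0 biases
# nodes towards halting early at the start of training.
#
#     halt = HaltingUnit(512)
#     p = halt(th.randn(16, 512))   # p.shape == (16, 1)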


class UTransformer(nn.Module):
    "Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
    MAX_DEPTH = 8
    thres = 0.99
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    act_loss_weight = 0.01

    def __init__(
        self,
        encoder,
        decoder,
        src_embed,
        tgt_embed,
        pos_enc,
        time_enc,
        generator,
        h,
        d_k,
    ):
        super(UTransformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc, self.time_enc = pos_enc, time_enc
        self.halt_enc = HaltingUnit(h * d_k)
        self.halt_dec = HaltingUnit(h * d_k)
        self.generator = generator
        self.h, self.d_k = h, d_k
        self.reset_stat()

    def reset_stat(self):
        self.stat = [0] * (self.MAX_DEPTH + 1)

    def step_forward(self, nodes):
        "Inject positional and per-step (depth) encodings before each ACT step."
        x = nodes.data["x"]
        step = nodes.data["step"]
        pos = nodes.data["pos"]
        return {
            "x": self.pos_enc.dropout(
                x + self.pos_enc(pos.view(-1)) + self.time_enc(step.view(-1))
            ),
            "step": step + 1,
        }

    def halt_and_accum(self, name, end=False):
        "name: 'enc' or 'dec'"
        halt = self.halt_enc if name == "enc" else self.halt_dec
        thres = self.thres

        def func(nodes):
            p = halt(nodes.data["x"])
            sum_p = nodes.data["sum_p"] + p
            active = (sum_p < thres) & (1 - end)
            _continue = active.float()
            r = nodes.data["r"] * (1 - _continue) + (1 - sum_p) * _continue
            s = (
                nodes.data["s"]
                + ((1 - _continue) * r + _continue * p) * nodes.data["x"]
            )
            return {"p": p, "sum_p": sum_p, "r": r, "s": s, "active": active}

        return func
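    # ACT bookkeeping implemented by halt_and_accum, following the ACT paper
    # cited in the class docstring: p is the halting probability emitted this
    # step, sum_p accumulates it, and a node stays active while sum_p < thres
    # and this is not the final step. While a node is active, its remainder is
    # kept at r = 1 - sum_p and its state accumulates s += p * x; on the step
    # where it halts, r is frozen and s += r * x is used instead. The mean
    # remainder over all nodes becomes the ponder penalty (act_loss) returned
    # by forward().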

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst("k", "q", "score"), eids)
        g.apply_edges(scaled_exp("score", np.sqrt(self.d_k)), eids)
        # Send weighted values to target nodes
        g.send_and_recv(
            eids,
            [fn.u_mul_e("v", "score", "v"), fn.copy_e("score", "score")],
            [fn.sum("v", "wv"), fn.sum("score", "z")],
        )
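        # After this message passing (assuming src_dot_dst and scaled_exp from
        # .functions behave as their names suggest), each destination node j
        # holds, over its incoming edges i -> j:
        #     score_ij = exp(q_j . k_i / sqrt(d_k))
        #     wv_j     = sum_i score_ij * v_i
        #     z_j      = sum_i score_ij
        # so the softmax-weighted value is recovered as wv / z in post_func.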

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."
        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        N, E = graph.n_nodes, graph.n_edges
        nids, eids = graph.nids, graph.eids

        # embed & pos
        g.nodes[nids["enc"]].data["x"] = self.src_embed(graph.src[0])
        g.nodes[nids["dec"]].data["x"] = self.tgt_embed(graph.tgt[0])
        g.nodes[nids["enc"]].data["pos"] = graph.src[1]
        g.nodes[nids["dec"]].data["pos"] = graph.tgt[1]

        # init step
        device = next(self.parameters()).device
        g.ndata["s"] = th.zeros(
            N, self.h * self.d_k, dtype=th.float, device=device
        )  # accumulated state
        g.ndata["p"] = th.zeros(
            N, 1, dtype=th.float, device=device
        )  # halting prob
        g.ndata["r"] = th.ones(N, 1, dtype=th.float, device=device)  # remainder
        g.ndata["sum_p"] = th.zeros(
            N, 1, dtype=th.float, device=device
        )  # sum of pondering values
        g.ndata["step"] = th.zeros(N, 1, dtype=th.long, device=device)  # step
        g.ndata["active"] = th.ones(
            N, 1, dtype=th.uint8, device=device
        )  # active

        for step in range(self.MAX_DEPTH):
            # Encoder self-attention over still-active encoder nodes
            # (encoder-to-encoder edges).
            pre_func = self.encoder.pre_func("qkv")
            post_func = self.encoder.post_func()
            nodes = g.filter_nodes(
                lambda v: v.data["active"].view(-1), nids["enc"]
            )
            if len(nodes) == 0:
                break
            edges = g.filter_edges(
                lambda e: e.dst["active"].view(-1), eids["ee"]
            )
            end = step == self.MAX_DEPTH - 1
            self.update_graph(
                g,
                edges,
                [(self.step_forward, nodes), (pre_func, nodes)],
                [(post_func, nodes), (self.halt_and_accum("enc", end), nodes)],
            )

        g.nodes[nids["enc"]].data["x"] = self.encoder.norm(
            g.nodes[nids["enc"]].data["s"]
        )

        for step in range(self.MAX_DEPTH):
            # Decoder self-attention over still-active decoder nodes
            # (decoder-to-decoder edges).
            pre_func = self.decoder.pre_func("qkv")
            post_func = self.decoder.post_func()
            nodes = g.filter_nodes(
                lambda v: v.data["active"].view(-1), nids["dec"]
            )
            if len(nodes) == 0:
                break
            edges = g.filter_edges(
                lambda e: e.dst["active"].view(-1), eids["dd"]
            )
            self.update_graph(
                g,
                edges,
                [(self.step_forward, nodes), (pre_func, nodes)],
                [(post_func, nodes)],
            )

            # Cross-attention: decoder queries attend to the encoder memory
            # (encoder-to-decoder edges).
            pre_q = self.decoder.pre_func("q", 1)
            pre_kv = self.decoder.pre_func("kv", 1)
            post_func = self.decoder.post_func(1)
            nodes_e = nids["enc"]
            edges = g.filter_edges(
                lambda e: e.dst["active"].view(-1), eids["ed"]
            )
            end = step == self.MAX_DEPTH - 1
            self.update_graph(
                g,
                edges,
                [(pre_q, nodes), (pre_kv, nodes_e)],
                [(post_func, nodes), (self.halt_and_accum("dec", end), nodes)],
            )

        g.nodes[nids["dec"]].data["x"] = self.decoder.norm(
            g.nodes[nids["dec"]].data["s"]
        )
        act_loss = th.mean(g.ndata["r"])  # ACT loss

        self.stat[0] += N
        for step in range(1, self.MAX_DEPTH + 1):
            self.stat[step] += th.sum(g.ndata["step"] >= step).item()

        return (
            self.generator(g.ndata["x"][nids["dec"]]),
            act_loss * self.act_loss_weight,
        )

    def infer(self, *args, **kwargs):
        raise NotImplementedError


def make_universal_model(
    src_vocab, tgt_vocab, dim_model=512, dim_ff=2048, h=8, dropout=0.1
):
    c = copy.deepcopy
    attn = MultiHeadAttention(h, dim_model)
    ff = PositionwiseFeedForward(dim_model, dim_ff)
    pos_enc = PositionalEncoding(dim_model, dropout)
    time_enc = PositionalEncoding(dim_model, dropout)
    encoder = UEncoder(EncoderLayer(dim_model, c(attn), c(ff), dropout))
    decoder = UDecoder(
        DecoderLayer(dim_model, c(attn), c(attn), c(ff), dropout)
    )
    src_embed = Embeddings(src_vocab, dim_model)
    tgt_embed = Embeddings(tgt_vocab, dim_model)
    generator = Generator(dim_model, tgt_vocab)
    model = UTransformer(
        encoder,
        decoder,
        src_embed,
        tgt_embed,
        pos_enc,
        time_enc,
        generator,
        h,
        dim_model // h,
    )
    # xavier init
    for p in model.parameters():
        if p.dim() > 1:
            INIT.xavier_uniform_(p)
    return model
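

# A minimal construction sketch (vocabulary sizes, `batched_graph`, `criterion`
# and `targets` are placeholders). The returned model consumes a batched graph
# wrapper exposing g, n_nodes, n_edges, nids, eids, src and tgt, exactly as
# read in UTransformer.forward above; building that wrapper is handled
# elsewhere in this example.
#
#     model = make_universal_model(src_vocab=10000, tgt_vocab=10000)
#     out, act_loss = model(batched_graph)
#     loss = criterion(out, targets) + act_loss   # ACT penalty already weighted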