import torch
from modules import BiLSTM, GraphTrans, MSA
from torch import nn
from utlis import *  # provides len2mask and replace_ent used below

import dgl


class GraphWriter(nn.Module):
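    """Graph-to-text model.

    Entities (and, optionally, the title) are encoded with BiLSTMs, the
    entity graph with a graph transformer (GraphTrans), and the output text
    is decoded by an LSTM cell with attention (MSA) plus a copy mechanism
    over the input entities.
    """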
    def __init__(self, args):
        super(GraphWriter, self).__init__()
        self.args = args
        if args.title:
            self.title_emb = nn.Embedding(
                len(args.title_vocab), args.nhid, padding_idx=0
            )
            self.title_enc = BiLSTM(args, enc_type="title")
            self.title_attn = MSA(args)
        self.ent_emb = nn.Embedding(
            len(args.ent_text_vocab), args.nhid, padding_idx=0
        )
        self.tar_emb = nn.Embedding(
            len(args.text_vocab), args.nhid, padding_idx=0
        )
        if args.title:
            nn.init.xavier_normal_(self.title_emb.weight)
        nn.init.xavier_normal_(self.ent_emb.weight)
        self.rel_emb = nn.Embedding(
            len(args.rel_vocab), args.nhid, padding_idx=0
        )
        nn.init.xavier_normal_(self.rel_emb.weight)
        self.decode_lstm = nn.LSTMCell(args.dec_ninp, args.nhid)
        self.ent_enc = BiLSTM(args, enc_type="entity")
        self.graph_enc = GraphTrans(args)
        self.ent_attn = MSA(args)
        self.copy_attn = MSA(args, mode="copy")
        self.copy_fc = nn.Linear(args.dec_ninp, 1)
        self.pred_v_fc = nn.Linear(args.dec_ninp, len(args.text_vocab))

    def enc_forward(
        self, batch, ent_mask, ent_text_mask, ent_len, rel_mask, title_mask
    ):
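        """Encode one batch.

        Returns graph-contextualised entity states (g_ent), a global graph
        state (g_root), the title encoding (or None), and the raw BiLSTM
        entity encodings (ent_enc).
        """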
        title_enc = None
        if self.args.title:
            title_enc = self.title_enc(
                self.title_emb(batch["title"]), title_mask
            )
        ent_enc = self.ent_enc(
            self.ent_emb(batch["ent_text"]),
            ent_text_mask,
            ent_len=batch["ent_len"],
        )
        rel_emb = self.rel_emb(batch["rel"])
        g_ent, g_root = self.graph_enc(
            ent_enc, ent_mask, ent_len, rel_emb, rel_mask, batch["graph"]
        )
        return g_ent, g_root, title_enc, ent_enc

    def forward(self, batch, beam_size=-1):
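        """Run training or decoding, selected by beam_size.

        beam_size < 1: teacher-forced training; returns per-step log scores
        over the text vocabulary concatenated with the entity copy scores.
        beam_size == 1: greedy decoding; returns the generated token ids.
        beam_size > 1: beam search; returns the best sequence per example.
        """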
        ent_mask = len2mask(batch["ent_len"], self.args.device)
        ent_text_mask = batch["ent_text"] == 0
        rel_mask = batch["rel"] == 0  # index 0 is the <PAD> token
        title_mask = batch["title"] == 0
        g_ent, g_root, title_enc, ent_enc = self.enc_forward(
            batch,
            ent_mask,
            ent_text_mask,
            batch["ent_len"],
            rel_mask,
            title_mask,
        )

        # Initialise the decoder state from the global graph state and build
        # the first attention context over entities (and the title, if any).
        _h, _c = g_root, g_root.clone().detach()
        ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
        if self.args.title:
            attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
            ctx = torch.cat([ctx, attn], 1)
        if beam_size < 1:
            # training
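            # Teacher forcing: embed the gold tokens from the batch and run
            # the decoder step by step, collecting state + context per step.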
            outs = []
            tar_inp = self.tar_emb(batch["text"].transpose(0, 1))
            for t, xt in enumerate(tar_inp):
                _xt = torch.cat([ctx, xt], 1)
                _h, _c = self.decode_lstm(_xt, (_h, _c))
                ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                if self.args.title:
                    attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
                    ctx = torch.cat([ctx, attn], 1)
                outs.append(torch.cat([_h, ctx], 1))
            outs = torch.stack(outs, 1)
            copy_gate = torch.sigmoid(self.copy_fc(outs))
            EPSI = 1e-6
            # copy
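            # The gate interpolates, in log space, between pred_v (scores
            # over the text vocabulary) and pred_c (copy scores over the
            # input entities); the two blocks are concatenated below.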
            pred_v = torch.log(copy_gate + EPSI) + torch.log_softmax(
                self.pred_v_fc(outs), -1
            )
            pred_c = torch.log((1.0 - copy_gate) + EPSI) + torch.log_softmax(
                self.copy_attn(outs, ent_enc, mask=ent_mask), -1
            )
            pred = torch.cat([pred_v, pred_c], -1)
            return pred
        else:
            if beam_size == 1:
                # greedy
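                # Greedy decoding: ids beyond the text vocabulary refer to
                # copied entities, so replace_ent maps them (via ent_type)
                # back to embeddable token ids before the next step.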
                device = g_ent.device
                B = g_ent.shape[0]
                ent_type = batch["ent_type"].view(B, -1)
                seq = (
                    torch.ones(B).long().to(device)
                    * self.args.text_vocab("<BOS>")
                ).unsqueeze(1)
                for t in range(self.args.beam_max_len):
                    _inp = replace_ent(
                        seq[:, -1], ent_type, len(self.args.text_vocab)
                    )
                    xt = self.tar_emb(_inp)
                    _xt = torch.cat([ctx, xt], 1)
                    _h, _c = self.decode_lstm(_xt, (_h, _c))
                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                    if self.args.title:
                        attn = _h + self.title_attn(
                            _h, title_enc, mask=title_mask
                        )
                        ctx = torch.cat([ctx, attn], 1)
                    _y = torch.cat([_h, ctx], 1)
                    copy_gate = torch.sigmoid(self.copy_fc(_y))
                    # epsilon keeps the logs finite if the gate saturates,
                    # mirroring the training branch
                    eps = 1e-6
                    pred_v = torch.log(copy_gate + eps) + torch.log_softmax(
                        self.pred_v_fc(_y), -1
                    )
                    pred_c = torch.log(
                        1.0 - copy_gate + eps
                    ) + torch.log_softmax(
                        self.copy_attn(
                            _y.unsqueeze(1), ent_enc, mask=ent_mask
                        ).squeeze(1),
                        -1,
                    )
                    pred = torch.cat([pred_v, pred_c], -1).view(B, -1)
                    for ban_item in ["<BOS>", "<PAD>", "<UNK>"]:
                        pred[:, self.args.text_vocab(ban_item)] = -1e8
                    _, word = pred.max(-1)
                    seq = torch.cat([seq, word.unsqueeze(1)], 1)
                return seq
            else:
                # beam search
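                # Tile every per-example tensor (states, context, masks,
                # entity encodings) beam_size times so the effective batch
                # is B * beam_size rows.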
                device = g_ent.device
                B = g_ent.shape[0]
                BSZ = B * beam_size
                _h = _h.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                _c = _c.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                ent_mask = (
                    ent_mask.view(B, 1, -1)
                    .repeat(1, beam_size, 1)
                    .view(BSZ, -1)
                )
                if self.args.title:
                    title_mask = (
                        title_mask.view(B, 1, -1)
                        .repeat(1, beam_size, 1)
                        .view(BSZ, -1)
                    )
                    title_enc = (
                        title_enc.view(B, 1, title_enc.size(1), -1)
                        .repeat(1, beam_size, 1, 1)
                        .view(BSZ, title_enc.size(1), -1)
                    )
                ctx = ctx.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                ent_type = (
                    batch["ent_type"]
                    .view(B, 1, -1)
                    .repeat(1, beam_size, 1)
                    .view(BSZ, -1)
                )
                g_ent = (
                    g_ent.view(B, 1, g_ent.size(1), -1)
                    .repeat(1, beam_size, 1, 1)
                    .view(BSZ, g_ent.size(1), -1)
                )
                ent_enc = (
                    ent_enc.view(B, 1, ent_enc.size(1), -1)
                    .repeat(1, beam_size, 1, 1)
                    .view(BSZ, ent_enc.size(1), -1)
                )

                beam_best = torch.zeros(B).to(device) - 1e9
                beam_best_seq = [None] * B
                beam_seq = (
                    torch.ones(B, beam_size).long().to(device)
                    * self.args.text_vocab("<BOS>")
                ).unsqueeze(-1)
                beam_score = torch.zeros(B, beam_size).to(device)
                done_flag = torch.zeros(B, beam_size)
                for t in range(self.args.beam_max_len):
                    _inp = replace_ent(
                        beam_seq[:, :, -1].view(-1),
                        ent_type,
                        len(self.args.text_vocab),
                    )
                    xt = self.tar_emb(_inp)
                    _xt = torch.cat([ctx, xt], 1)
                    _h, _c = self.decode_lstm(_xt, (_h, _c))
                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                    if self.args.title:
                        attn = _h + self.title_attn(
                            _h, title_enc, mask=title_mask
                        )
                        ctx = torch.cat([ctx, attn], 1)
                    _y = torch.cat([_h, ctx], 1)
                    copy_gate = torch.sigmoid(self.copy_fc(_y))
                    eps = 1e-6  # keep the logs finite if the gate saturates
                    pred_v = torch.log(copy_gate + eps) + torch.log_softmax(
                        self.pred_v_fc(_y), -1
                    )
                    pred_c = torch.log(
                        1.0 - copy_gate + eps
                    ) + torch.log_softmax(
                        self.copy_attn(
                            _y.unsqueeze(1), ent_enc, mask=ent_mask
                        ).squeeze(1),
                        -1,
                    )
                    pred = torch.cat([pred_v, pred_c], -1).view(
                        B, beam_size, -1
                    )
                    for ban_item in ["<BOS>", "<PAD>", "<UNK>"]:
                        pred[:, :, self.args.text_vocab(ban_item)] = -1e8
                    if t == self.args.beam_max_len - 1:  # force ending
                        tt = pred[:, :, self.args.text_vocab("<EOS>")]
                        pred = pred * 0 - 1e8
                        pred[:, :, self.args.text_vocab("<EOS>")] = tt
                    cum_score = beam_score.view(B, beam_size, 1) + pred
                    score, word = cum_score.topk(
                        dim=-1, k=beam_size
                    )  # B, beam_size, beam_size
                    score, word = score.view(B, -1), word.view(B, -1)
                    eos_idx = self.args.text_vocab("<EOS>")
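                    # At step one all beams hold the same <BOS> hypothesis,
                    # so take one candidate per beam; afterwards keep the
                    # top beam_size of the beam_size**2 expansions.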
                    if beam_seq.size(2) == 1:
                        new_idx = torch.arange(beam_size).to(word)
                        new_idx = new_idx[None, :].repeat(B, 1)
                    else:
                        _, new_idx = score.topk(dim=-1, k=beam_size)
                    new_src, new_score, new_word, new_done = [], [], [], []
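                    # length-penalty normaliser for ranking finished beams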
                    LP = beam_seq.size(2) ** self.args.lp
                    for i in range(B):
                        for j in range(beam_size):
                            tmp_score = score[i][new_idx[i][j]]
                            tmp_word = word[i][new_idx[i][j]]
                            src_idx = new_idx[i][j] // beam_size
                            new_src.append(src_idx)
                            if tmp_word == eos_idx:
                                new_score.append(-1e8)
                            else:
                                new_score.append(tmp_score)
                            new_word.append(tmp_word)

                            if (
                                tmp_word == eos_idx
                                and done_flag[i][src_idx] == 0
                                and tmp_score / LP > beam_best[i]
                            ):
                                beam_best[i] = tmp_score / LP
                                beam_best_seq[i] = beam_seq[i][src_idx]
                            if tmp_word == eos_idx:
                                new_done.append(1)
                            else:
                                new_done.append(done_flag[i][src_idx])
                    new_score = (
                        torch.Tensor(new_score)
                        .view(B, beam_size)
                        .to(beam_score)
                    )
                    new_word = (
                        torch.Tensor(new_word).view(B, beam_size).to(beam_seq)
                    )
                    new_src = (
                        torch.LongTensor(new_src).view(B, beam_size).to(device)
                    )
                    new_done = (
                        torch.Tensor(new_done).view(B, beam_size).to(done_flag)
                    )
                    beam_score = new_score
                    done_flag = new_done
                    beam_seq = beam_seq.view(B, beam_size, -1)[
                        torch.arange(B)[:, None].to(device), new_src
                    ]
                    beam_seq = torch.cat([beam_seq, new_word.unsqueeze(2)], 2)
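                    # Reorder the recurrent state and context to follow the
                    # surviving beams.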
                    _h = _h.view(B, beam_size, -1)[
                        torch.arange(B)[:, None].to(device), new_src
                    ].view(BSZ, -1)
                    _c = _c.view(B, beam_size, -1)[
                        torch.arange(B)[:, None].to(device), new_src
                    ].view(BSZ, -1)
                    ctx = ctx.view(B, beam_size, -1)[
                        torch.arange(B)[:, None].to(device), new_src
                    ].view(BSZ, -1)

                return beam_best_seq