# main.py: trains the DGLGATNE embedding model and reports ROC-AUC, PR-AUC and
# F1 on held-out edges.
import math
import os
import sys
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import random
from torch.nn.parameter import Parameter
from tqdm.auto import tqdm
from utils import *

import dgl
import dgl.function as fn


def get_graph(network_data, vocab):
    """Build graph, treat all nodes as the same type

    Each input edge is inserted in both directions (``edge[0] -> edge[1]`` and
    ``edge[1] -> edge[0]``), so the resulting graph is effectively undirected.

    Parameters
    ----------
    network_data: a dict
        keys describing the edge types, values representing edges; each edge
        is indexable as ``edge[0]`` / ``edge[1]`` (its two endpoint IDs)
    vocab: a dict
        mapping node IDs to node indices

    Returns
    -------
    DGLGraph
        a heterogeneous graph with a single node type ``"_N"`` holding
        ``len(vocab)`` nodes, and one relation per key of ``network_data``
    """
    node_type = "_N"  # '_N' can be replaced by an arbitrary name
    data_dict = {}
    num_nodes_dict = {node_type: len(vocab)}

    for edge_type, edges in network_data.items():
        src = []
        dst = []
        for edge in edges:
            u, v = vocab[edge[0]], vocab[edge[1]]
            # Store both directions of the edge.
            src.extend((u, v))
            dst.extend((v, u))
        data_dict[(node_type, edge_type, node_type)] = (src, dst)

    return dgl.heterograph(data_dict, num_nodes_dict)


class NeighborSampler:
    """Turn a batch of ``(head, tail, type)`` triples into sampled DGL blocks.

    ``sample`` is shaped to serve as a ``DataLoader`` ``collate_fn``.

    Parameters
    ----------
    g : DGLGraph
        graph to draw neighbors from
    num_fanouts : sequence of int
        number of neighbors sampled per layer; the LAST entry is used for the
        block whose destination nodes are the heads
    """

    def __init__(self, g, num_fanouts):
        self.g = g
        self.num_fanouts = num_fanouts

    def sample(self, pairs):
        """Sample neighbor blocks around the unique heads of ``pairs``.

        Parameters
        ----------
        pairs : iterable of 3-element sequences
            ``(head, tail, type)`` triples

        Returns
        -------
        tuple
            ``(blocks, head_invmap, tails, types)``: ``blocks`` is ordered from
            the outermost layer to the block whose destination nodes are the
            unique heads (as returned by ``torch.unique``); ``head_invmap[k]``
            is the position of the k-th head among those destination nodes;
            ``tails`` and ``types`` are the other two columns as LongTensors.
        """
        head_ids, tail_ids, pair_types = zip(*pairs)
        frontier, head_invmap = torch.unique(
            torch.LongTensor(head_ids), return_inverse=True
        )

        # Expand outward from the heads, using the fanouts innermost-first.
        layers = []
        for fanout in reversed(self.num_fanouts):
            frontier_graph = dgl.sampling.sample_neighbors(self.g, frontier, fanout)
            layer = dgl.to_block(frontier_graph, frontier)
            frontier = layer.srcdata[dgl.NID]
            layers.append(layer)
        # Collected innermost-first; callers expect the head-adjacent block last.
        layers.reverse()

        return (
            layers,
            torch.LongTensor(head_invmap),
            torch.LongTensor(tail_ids),
            torch.LongTensor(pair_types),
        )


class DGLGATNE(nn.Module):
    """GATNE-style model producing one embedding per edge type for each node.

    For every destination node of a block, the output for edge type ``t`` is
    the node's base embedding plus ``trans_weights[t]`` applied to an
    attention-weighted mix of that node's per-edge-type neighbor aggregates,
    then L2-normalized.

    Parameters
    ----------
    num_nodes : int
        number of rows in the embedding tables
    embedding_size : int
        size of the base embeddings and of the output embeddings
    embedding_u_size : int
        size of the per-edge-type embeddings
    edge_types : sequence
        edge-type names; used as relation names on the input blocks
    edge_type_count : int
        number of edge types (the first ``edge_type_count`` of ``edge_types``)
    dim_a : int
        hidden size of the attention scoring layer
    """

    def __init__(
        self,
        num_nodes,
        embedding_size,
        embedding_u_size,
        edge_types,
        edge_type_count,
        dim_a,
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_types = edge_types
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a

        # Uninitialized storage; reset_parameters() fills it below. This tuple
        # fixes the registration order seen by parameters() / state_dict().
        param_shapes = (
            ("node_embeddings", (num_nodes, embedding_size)),
            ("node_type_embeddings", (num_nodes, edge_type_count, embedding_u_size)),
            ("trans_weights", (edge_type_count, embedding_u_size, embedding_size)),
            ("trans_weights_s1", (edge_type_count, embedding_u_size, dim_a)),
            ("trans_weights_s2", (edge_type_count, dim_a, 1)),
        )
        for attr_name, shape in param_shapes:
            setattr(self, attr_name, Parameter(torch.FloatTensor(*shape)))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every parameter in place.

        Embedding tables are drawn from U(-1, 1); projection and attention
        weights from N(0, 1 / sqrt(embedding_size)).
        """
        for table in (self.node_embeddings, self.node_type_embeddings):
            table.data.uniform_(-1.0, 1.0)
        std = 1.0 / math.sqrt(self.embedding_size)
        for weight in (
            self.trans_weights,
            self.trans_weights_s1,
            self.trans_weights_s2,
        ):
            weight.data.normal_(std=std)

    def forward(self, block):
        """Embed the destination nodes of ``block`` once per edge type.

        Returns
        -------
        torch.Tensor
            ``[num_dst_nodes, edge_type_count, embedding_size]``, each vector
            L2-normalized along the last dimension
        """
        src_ids = block.srcdata[dgl.NID]
        dst_ids = block.dstdata[dgl.NID]
        num_dst = block.number_of_dst_nodes()
        n_types = self.edge_type_count
        u_dim = self.embedding_u_size

        def tile(weight):
            # [n_types, a, b] -> [num_dst * n_types, a, b]; row k * n_types + t
            # holds weight[t], lining up with the per-(node, type) rows below.
            return (
                weight.unsqueeze(0)
                .repeat(num_dst, 1, 1, 1)
                .view(-1, weight.shape[1], weight.shape[2])
            )

        with block.local_scope():
            per_type = []
            for t in range(n_types):
                etype = self.edge_types[t]
                block.srcdata[etype] = self.node_type_embeddings[src_ids, t]
                # NOTE(review): update_all below writes dstdata[etype], so this
                # value appears to be overwritten; kept to preserve behavior.
                block.dstdata[etype] = self.node_type_embeddings[dst_ids, t]
                # Sum the type-t embeddings of each dst node's in-neighbors
                # along relation ``etype``.
                block.update_all(
                    fn.copy_u(etype, "m"),
                    fn.sum("m", etype),
                    etype=etype,
                )
                per_type.append(block.dstdata[etype])

            # [num_dst, n_types, u_dim]
            type_embed = torch.stack(per_type, 1)
            # One row per (node, edge type): [num_dst * n_types, 1, u_dim]
            type_rows = type_embed.view(-1, 1, u_dim)

            w_out = tile(self.trans_weights)
            w_s1 = tile(self.trans_weights_s1)
            w_s2 = tile(self.trans_weights_s2)

            # Attention scores over edge types: [num_dst, n_types]
            hidden = torch.tanh(torch.matmul(type_rows, w_s1))
            scores = torch.matmul(hidden, w_s2).squeeze(2).view(-1, n_types)
            weights = F.softmax(scores, dim=1)
            # Every output edge type reuses the same weights over the input
            # types: [num_dst, n_types, n_types]
            attention = weights.unsqueeze(1).repeat(1, n_types, 1)
            mixed = torch.matmul(attention, type_embed).view(-1, 1, u_dim)

            base = self.node_embeddings[dst_ids].unsqueeze(1).repeat(1, n_types, 1)
            projected = torch.matmul(mixed, w_out).view(
                -1, n_types, self.embedding_size
            )
            # [num_dst, edge_type_count, embedding_size]
            return F.normalize(base + projected, dim=2)


class NSLoss(nn.Module):
    """Negative-sampling loss over (head embedding, tail node) pairs.

    Negatives are drawn from a fixed distribution over node indices that
    decreases with the index (presumably vocab indices are ordered by
    frequency -- confirm against ``generate_vocab``).

    Parameters
    ----------
    num_nodes : int
        number of candidate tail / negative nodes
    num_sampled : int
        negatives drawn per positive pair
    embedding_size : int
        size of the input embeddings and of the context weight rows
    """

    def __init__(self, num_nodes, num_sampled, embedding_size):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))

        # [ (log(i+2) - log(i+1)) / log(num_nodes + 1)], computed vectorized in
        # float64 rather than with a Python loop over every node.
        ranks = torch.arange(num_nodes, dtype=torch.float64)
        raw = (torch.log(ranks + 2) - torch.log(ranks + 1)) / math.log(num_nodes + 1)
        # Registered as a non-persistent buffer so `.to(device)` moves it along
        # with the module (multinomial then runs on the same device) while the
        # state_dict keys stay exactly as before.
        self.register_buffer(
            "sample_weights",
            F.normalize(raw.to(torch.get_default_dtype()), dim=0),
            persistent=False,
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the context weights from N(0, 1 / sqrt(embedding_size))."""
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        """Return the batch-mean negative-sampling loss.

        Parameters
        ----------
        input : torch.Tensor
            only ``input.shape[0]`` (the batch size ``n``) is used
        embs : torch.Tensor
            ``[n, embedding_size]`` head embeddings
        label : torch.Tensor
            ``[n]`` positive (tail) node indices

        Returns
        -------
        torch.Tensor
            scalar loss
        """
        n = input.shape[0]
        # F.logsigmoid is the numerically stable log(sigmoid(x)): it stays
        # finite for large negative scores, where sigmoid underflows to 0 and
        # the naive log returns -inf (and NaN gradients).
        log_target = F.logsigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1))
        negs = torch.multinomial(
            self.sample_weights, self.num_sampled * n, replacement=True
        ).view(n, self.num_sampled)
        noise = torch.neg(self.weights[negs])  # [n, num_sampled, embedding_size]
        # [n, num_sampled, 1] -> sum over samples -> [n]; squeeze(1) keeps the
        # batch dimension even when n == 1.
        sum_log_sampled = torch.sum(
            F.logsigmoid(torch.bmm(noise, embs.unsqueeze(2))), 1
        ).squeeze(1)

        loss = log_target + sum_log_sampled
        return -loss.sum() / n


def train_model(network_data):
    """Train DGLGATNE on ``network_data`` and evaluate it after every epoch.

    Hyperparameters are read from the module-level ``args``; validation and
    test edges from the module-level ``valid_true_data_by_edge``,
    ``valid_false_data_by_edge``, ``testing_true_data_by_edge`` and
    ``testing_false_data_by_edge`` (all assigned in the ``__main__`` block).

    Parameters
    ----------
    network_data: a dict
        keys describing the edge types, values representing edges

    Returns
    -------
    tuple
        ``(average_auc, average_f1, average_pr)``: test ROC-AUC, F1 and PR-AUC
        averaged over the evaluated edge types, from the last epoch that ran
        (all NaN if no epoch ran).
    """
    index2word, vocab, type_nodes = generate_vocab(network_data)

    edge_types = list(network_data.keys())
    num_nodes = len(index2word)
    edge_type_count = len(edge_types)
    epochs = args.epoch
    batch_size = args.batch_size
    embedding_size = args.dimensions
    embedding_u_size = args.edge_dim
    num_sampled = args.negative_samples
    dim_a = args.att_dim
    neighbor_samples = args.neighbor_samples
    num_workers = args.workers

    device = torch.device(
        "cuda" if args.gpu is not None and torch.cuda.is_available() else "cpu"
    )

    g = get_graph(network_data, vocab)
    all_walks = []
    for i in range(edge_type_count):
        nodes = torch.LongTensor(type_nodes[i] * args.num_walks)
        # Random walks that follow only edge type i; each trace has
        # neighbor_samples positions (metapath length + 1).
        traces, _ = dgl.sampling.random_walk(
            g, nodes, metapath=[edge_types[i]] * (neighbor_samples - 1)
        )
        all_walks.append(traces)

    train_pairs = generate_pairs(all_walks, args.window_size, num_workers)
    neighbor_sampler = NeighborSampler(g, [neighbor_samples])
    train_dataloader = torch.utils.data.DataLoader(
        train_pairs,
        batch_size=batch_size,
        collate_fn=neighbor_sampler.sample,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    model = DGLGATNE(
        num_nodes,
        embedding_size,
        embedding_u_size,
        edge_types,
        edge_type_count,
        dim_a,
    )
    nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
    model.to(device)
    nsloss.to(device)

    optimizer = torch.optim.Adam(
        [{"params": model.parameters()}, {"params": nsloss.parameters()}],
        lr=1e-3,
    )

    # Edge types to evaluate; only consulted when eval_type is not "all".
    eval_types = set(args.eval_type.split(","))

    best_score = 0
    patience = 0
    # Defaults so that a run with epochs <= 0 returns NaNs instead of raising
    # UnboundLocalError at the return statement.
    average_auc = average_f1 = average_pr = float("nan")
    for epoch in range(epochs):
        model.train()
        # No manual shuffle of train_pairs: the DataLoader (shuffle=True)
        # already draws a fresh permutation every epoch.
        data_iter = tqdm(
            train_dataloader,
            desc="epoch %d" % (epoch),
            total=len(train_dataloader),
        )
        avg_loss = 0.0

        for i, (blocks, head_invmap, tails, block_types) in enumerate(data_iter):
            optimizer.zero_grad()
            block_types = block_types.to(device)
            # embs: [batch_size, edge_type_count, embedding_size], one row per head
            embs = model(blocks[0].to(device))[head_invmap]
            # Keep only each pair's own edge type: [batch_size, embedding_size]
            embs = embs.gather(
                1,
                block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2]),
            )[:, 0]
            loss = nsloss(
                blocks[0].dstdata[dgl.NID][head_invmap].to(device),
                embs,
                tails.to(device),
            )
            loss.backward()
            optimizer.step()
            # One device sync per step instead of two.
            loss_value = loss.item()
            avg_loss += loss_value

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "loss": loss_value,
            }
            data_iter.set_postfix(post_fix)

        model.eval()
        # {edge_type: {node_id: embedding}}
        final_model = _infer_node_embeddings(
            model,
            neighbor_sampler,
            num_nodes,
            batch_size,
            index2word,
            edge_types,
            device,
        )

        valid_aucs, valid_f1s, valid_prs = [], [], []
        test_aucs, test_f1s, test_prs = [], [], []
        for edge_type in edge_types:
            if args.eval_type != "all" and edge_type not in eval_types:
                continue
            tmp_auc, tmp_f1, tmp_pr = evaluate(
                final_model[edge_type],
                valid_true_data_by_edge[edge_type],
                valid_false_data_by_edge[edge_type],
                num_workers,
            )
            valid_aucs.append(tmp_auc)
            valid_f1s.append(tmp_f1)
            valid_prs.append(tmp_pr)

            tmp_auc, tmp_f1, tmp_pr = evaluate(
                final_model[edge_type],
                testing_true_data_by_edge[edge_type],
                testing_false_data_by_edge[edge_type],
                num_workers,
            )
            test_aucs.append(tmp_auc)
            test_f1s.append(tmp_f1)
            test_prs.append(tmp_pr)
        print("valid auc:", np.mean(valid_aucs))
        print("valid pr:", np.mean(valid_prs))
        print("valid f1:", np.mean(valid_f1s))

        average_auc = np.mean(test_aucs)
        average_f1 = np.mean(test_f1s)
        average_pr = np.mean(test_prs)

        cur_score = np.mean(valid_aucs)
        if cur_score > best_score:
            best_score = cur_score
            patience = 0
        else:
            patience += 1
            if patience > args.patience:
                print("Early Stopping")
                break
    # NOTE(review): these test metrics come from the last epoch that ran, not
    # from the epoch with the best validation AUC -- confirm this is intended.
    return average_auc, average_f1, average_pr


def _infer_node_embeddings(
    model, neighbor_sampler, num_nodes, batch_size, index2word, edge_types, device
):
    """Compute every node's embedding for every edge type, in batches.

    The model returns ``[num_dst, edge_type_count, embedding_size]`` for the
    destination nodes of a block, so one forward pass per batch of nodes
    yields all edge types at once.

    Returns
    -------
    dict
        ``{edge_type: {index2word[i]: np.ndarray of shape [embedding_size]}}``
    """
    final_model = {edge_type: {} for edge_type in edge_types}
    with torch.no_grad():
        for start in range(0, num_nodes, batch_size):
            node_ids = range(start, min(start + batch_size, num_nodes))
            # Only the heads matter here; the tail/type fields are ignored.
            blocks, invmap, _, _ = neighbor_sampler.sample(
                [(i, i, 0) for i in node_ids]
            )
            # [len(node_ids), edge_type_count, embedding_size]
            embs = model(blocks[0].to(device))[invmap].cpu().numpy()
            for row, node in enumerate(node_ids):
                word = index2word[node]
                for j, edge_type in enumerate(edge_types):
                    final_model[edge_type][word] = embs[row, j]
    return final_model


if __name__ == "__main__":
    # train_model reads `args` and the four *_data_by_edge dicts as module
    # globals, so they must be assigned at module scope here.
    args = parse_args()
    file_name = args.input
    print(args)

    training_data_by_type = load_training_data(os.path.join(file_name, "train.txt"))
    valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(
        os.path.join(file_name, "valid.txt")
    )
    testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(
        os.path.join(file_name, "test.txt")
    )
    # perf_counter is monotonic, so the measured duration is not affected by
    # wall-clock adjustments during training.
    start = time.perf_counter()
    average_auc, average_f1, average_pr = train_model(training_data_by_type)
    end = time.perf_counter()

    print("Overall ROC-AUC:", average_auc)
    print("Overall PR-AUC", average_pr)
    print("Overall F1:", average_f1)
    print("Training Time", end - start)