# linkpred.jinja-py — Jinja2 template rendered into a DGL/PyTorch link-prediction training script.
import dgl
from dgl.data import AsLinkPredDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np

{{ data_import_code }}


{{ edge_model_code }}

{{ node_model_code}}

class Model(nn.Module):
    """Link-prediction model: a node encoder plus an edge scorer.

    Parameters
    ----------
    node_model : nn.Module
        Maps (graph, node features) to per-node embeddings.
    edge_model : nn.Module
        Maps (src embedding, dst embedding) to an edge score.
    neg_sampler : callable
        Negative-edge sampler used during training (see ``train``).
    eval_batch_size : int
        Number of edges scored per batch in ``inference`` to bound memory.
    """

    def __init__(self, node_model, edge_model, neg_sampler, eval_batch_size):
        super().__init__()
        self.node_model = node_model
        self.edge_model = edge_model
        self.neg_sampler = neg_sampler
        self.eval_batch_size = eval_batch_size

    def inference(self, g, x, edges):
        """Score every edge in ``edges`` on graph ``g`` with features ``x``.

        Node embeddings are computed once; edges are then scored in
        mini-batches of ``eval_batch_size`` and the per-batch scores are
        concatenated, preserving the input edge order.
        """
        src, dst = edges
        h = self.node_model(g, x)
        eid_dataloader = DataLoader(
            range(src.shape[-1]),
            batch_size=self.eval_batch_size)
        score_list = []
        for eids in eid_dataloader:
            score = self.edge_model(h[src[eids]], h[dst[eids]])
            # Bug fix: append must happen inside the loop. Previously it sat
            # outside, so only the LAST batch's scores were returned and the
            # output tensor was silently truncated.
            score_list.append(score)
        return torch.cat(score_list, dim=0)

def calc_hitsk(y_pred_pos, y_pred_neg, k):
    """Hits@k: fraction of positive scores strictly above the k-th best negative score."""
    neg_scores = y_pred_neg.flatten()
    threshold = torch.topk(neg_scores, k).values[-1]
    hits = (y_pred_pos > threshold).float()
    return hits.mean().item()

def train(cfg, pipeline_cfg, device, dataset, model, optimizer, loss_fcn):
    """Train ``model`` on the dataset's training graph and return test Hits@50.

    Runs ``num_epochs`` of mini-batch edge training (full-graph node encoding
    per batch), evaluates on the validation split every ``eval_period`` epochs,
    and finishes with one evaluation on the test split.

    Parameters
    ----------
    cfg : dict
        Full user configuration (unused directly here; kept for interface parity).
    pipeline_cfg : dict
        General pipeline settings: num_epochs, train_batch_size, eval_period, ...
    device : str or torch.device
        Device the training graph is moved to.
    dataset : AsLinkPredDataset
        Provides ``train_graph``, ``val_edges`` and ``test_edges``.
    model : Model
        Node encoder + edge scorer + negative sampler wrapper.
    optimizer : torch.optim.Optimizer
    loss_fcn : callable
        Binary loss over concatenated positive/negative edge scores.

    Returns
    -------
    float
        Test Hits@50 after the final epoch.
    """
    train_g = dataset.train_graph
    train_g = train_g.to(device)
    node_feat = train_g.ndata['feat']
    train_src, train_dst = train_g.edges()
    for epoch in range(pipeline_cfg['num_epochs']):
        model.train()
        # Shuffled mini-batches of edge IDs; node embeddings are recomputed
        # over the FULL graph for every batch (simple but O(batches) encoder cost).
        eid_dataloader = DataLoader(range(train_g.num_edges()), batch_size = pipeline_cfg["train_batch_size"], shuffle=True)
        for eids in eid_dataloader:
            h = model.node_model(train_g, node_feat)
            eids = eids.to(device)
            src, dst = train_src[eids], train_dst[eids]
            pos_score = model.edge_model(h[src], h[dst])
            # Negative endpoints sampled per positive batch; labels are 1 for
            # positives and 0 for negatives.
            neg_src, neg_dst = model.neg_sampler(train_g, eids)
            neg_score = model.edge_model(h[neg_src], h[neg_dst])
            loss = loss_fcn(torch.cat([pos_score, neg_score]),  torch.cat(
                [torch.ones_like(pos_score), torch.zeros_like(neg_score)]))
        
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping at a fixed max-norm of 1.0 for stability.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        with torch.no_grad():
            model.eval()
            # val_edges is (positive_edges, negative_edges); index 1 = negatives.
            val_neg_edges = dataset.val_edges[1]
            val_neg_score = model.inference(train_g, node_feat, val_neg_edges)
        # NOTE(review): this "Train Hits@50" ranks only the LAST training
        # batch's positive scores against the validation negatives — it is a
        # cheap progress proxy, not a true train-set metric. Confirm intended.
        train_hits = calc_hitsk(pos_score, val_neg_score, k=50)
        print("Epoch {:05d} | Loss {:.4f} | Train Hits@50 {:.4f}".format(epoch, loss, train_hits))

        # Full validation evaluation only every eval_period epochs (and never
        # at epoch 0).
        if epoch != 0 and epoch % pipeline_cfg['eval_period'] == 0:
            with torch.no_grad():
                model.eval()
                val_pos_edge, val_neg_edges = dataset.val_edges
                pos_result = model.inference(train_g, node_feat, val_pos_edge)
                neg_result = model.inference(train_g, node_feat, val_neg_edges)
                val_hits = calc_hitsk(pos_result, neg_result, k=50)
            print("Epoch {:05d} | Val Hits@50 {:.4f}".format(epoch, val_hits))

    # Final test evaluation after training completes.
    with torch.no_grad():
        model.eval()        
        test_pos_edge, test_neg_edges = dataset.test_edges
        pos_result = model.inference(train_g, node_feat, test_pos_edge)
        neg_result = model.inference(train_g, node_feat, test_neg_edges)
        test_hits = calc_hitsk(pos_result, neg_result, k=50)
        print("Test Hits@50 {:.4f}".format(test_hits))
    return test_hits


def main():
    """Build the dataset, models and optimizer from the rendered config,
    run one full training run, and return the test Hits@50.

    Jinja placeholders below are filled in at template-render time:
    ``user_cfg_str`` is expected to render to code defining ``cfg`` (a dict).
    """
    {{user_cfg_str}}
    device = cfg['device']
    pipeline_cfg = cfg['general_pipeline']
    # Wrap the raw dataset so it exposes train/val/test edge splits.
    dataset = AsLinkPredDataset({{ data_initialize_code }})
    # Graphs without node features require a learnable embedding table instead.
    if 'feat' not in dataset[0].ndata:
        assert cfg["node_model"]["embed_size"] > 0, "Need to specify embed size if graph doesn't have feat in ndata"
    # Input size comes from the embedding table when embed_size > 0, otherwise
    # from the raw feature dimensionality.
    cfg["node_model"]["data_info"] = {
        "in_size": cfg["node_model"]['embed_size'] if cfg["node_model"]['embed_size'] > 0 else dataset[0].ndata['feat'].shape[1],
        "out_size": pipeline_cfg['hidden_size'],
        "num_nodes": dataset[0].num_nodes()
        }
    cfg["edge_model"]["data_info"] = {
        "in_size": pipeline_cfg['hidden_size'],
        "out_size": 1 # output each edge score
    }
    node_model = {{node_model_class_name}}(**cfg["node_model"])
    edge_model = {{edge_model_class_name}}(**cfg["edge_model"])
    neg_sampler = dgl.dataloading.negative_sampler.{{ neg_sampler_name }}(**cfg["neg_sampler"])
    model = Model(node_model, edge_model, neg_sampler, pipeline_cfg["eval_batch_size"])
    model = model.to(device)
    params = model.parameters()
    loss = torch.nn.{{ loss }}()
    optimizer = torch.optim.Adam(params, **pipeline_cfg["optimizer"])
    test_hits = train(cfg, pipeline_cfg, device, dataset, model, optimizer, loss)
    return test_hits

if __name__ == '__main__':
    # Repeat the experiment num_runs times and report mean ± std of Hits@50.
    num_runs = {{ user_cfg.general_pipeline.num_runs }}
    all_acc = []
    for run in range(num_runs):
        print(f'Run experiment #{run}')
        run_hits = main()
        print("Test Hits@50 {:.4f}".format(run_hits))
        all_acc.append(run_hits)
    avg_acc = np.round(np.mean(all_acc), 6)
    std_acc = np.round(np.std(all_acc), 6)
    print(f'Test Hits@50 across {num_runs} runs: {avg_acc} ± {std_acc}')