# main.py (5.74 KB)
1
import os
2
3
import warnings

# Committed by Hongzhi (Steve), Chen
4
5
import dgl

6
import numpy as np
7
import torch
8
9
10
11
12
import torch.nn as nn
from model import PGNN
from sklearn.metrics import roc_auc_score
from utils import get_dataset, preselect_anchor

13
warnings.filterwarnings("ignore")
14
15


16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def get_loss(p, data, out, loss_func, device, get_auc=True):
    """Compute the edge-prediction loss (and optionally ROC-AUC) for a split.

    Every candidate edge is scored with the dot product of the two rows of
    ``out`` selected by its endpoints. Edges from
    ``data["positive_edges_<p>"]`` are labelled 1 and edges from
    ``data["negative_edges_<p>"]`` are labelled 0, in that order.

    Args:
        p: Split name inserted into the dictionary keys (callers in this
            file pass "train", "val" and "test").
        data: Mapping that holds the two edge arrays. Each array is read as
            row 0 = first endpoint, row 1 = second endpoint, and its
            ``shape[1]`` is taken as the number of edges.
        out: Tensor whose rows (dim 0) are looked up by node index.
        loss_func: Callable invoked as ``loss_func(pred, label)``.
        device: Device on which the label tensor is created.
        get_auc: When true, also compute ROC-AUC on ``sigmoid(pred)``.

    Returns:
        ``loss`` when ``get_auc`` is false, otherwise ``(loss, auc)``.
    """
    positive_edges = data["positive_edges_{}".format(p)]
    negative_edges = data["negative_edges_{}".format(p)]
    edge_mask = np.concatenate((positive_edges, negative_edges), axis=-1)

    # Convert the endpoint indices once and place them next to `out`, so
    # both lookups below reuse the same index tensor.
    edge_index = torch.from_numpy(edge_mask).long().to(out.device)
    nodes_first = torch.index_select(out, 0, edge_index[0])
    nodes_second = torch.index_select(out, 0, edge_index[1])

    # Dot-product score per edge; used as a logit by the AUC branch below.
    pred = torch.sum(nodes_first * nodes_second, dim=-1)

    # Build the labels directly on the target device instead of allocating
    # them on CPU and copying afterwards. Positives come first to match the
    # concatenation order of `edge_mask`.
    label = torch.cat(
        (
            torch.ones(
                positive_edges.shape[1], dtype=pred.dtype, device=device
            ),
            torch.zeros(
                negative_edges.shape[1], dtype=pred.dtype, device=device
            ),
        )
    )
    loss = loss_func(pred, label)

    if not get_auc:
        return loss

    # `.detach()` is the supported way to leave the autograd graph before
    # converting to NumPy (the legacy `.data` attribute is discouraged).
    auc = roc_auc_score(
        label.flatten().cpu().numpy(),
        torch.sigmoid(pred).flatten().detach().cpu().numpy(),
    )
    return loss, auc

58

59
60
61
62
def train_model(data, model, loss_func, optimizer, device, g_data):
    """Perform one optimisation step on the training split.

    Args:
        data: Dataset mapping forwarded to ``get_loss``.
        model: Module called as ``model(g_data)`` to produce the output
            that ``get_loss`` scores.
        loss_func: Loss callable forwarded to ``get_loss``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are invoked.
        device: Device forwarded to ``get_loss``.
        g_data: Input passed to the model.

    Returns:
        The same ``g_data`` object that was passed in.
    """
    model.train()

    # Forward pass and training loss only; AUC is not needed for the update.
    train_loss = get_loss(
        "train", data, model(g_data), loss_func, device, get_auc=False
    )

    # Clear gradients, backpropagate, apply the update, then clear again
    # so no gradients are left on the parameters after this call.
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return g_data

72

73
74
75
76
77
def eval_model(data, g_data, model, loss_func, device):
    """Evaluate the model on the train, validation and test splits.

    A single forward pass is made in eval mode and its output is scored
    against each split with ``get_loss``.

    Args:
        data: Dataset mapping forwarded to ``get_loss``.
        g_data: Input passed to the model.
        model: Module called as ``model(g_data)``.
        loss_func: Loss callable forwarded to ``get_loss``.
        device: Device forwarded to ``get_loss``.

    Returns:
        Tuple ``(loss_train, auc_train, auc_val, auc_test)`` where
        ``loss_train`` is the training loss converted to a NumPy value.
    """
    model.eval()

    # Nothing is optimised here, so disable gradient tracking: this avoids
    # building an autograd graph for the forward pass and the three losses.
    with torch.no_grad():
        out = model(g_data)

        # train loss and auc
        tmp_loss, auc_train = get_loss("train", data, out, loss_func, device)

        # val auc (the loss is discarded)
        _, auc_val = get_loss("val", data, out, loss_func, device)

        # test auc (the loss is discarded)
        _, auc_test = get_loss("test", data, out, loss_func, device)

    loss_train = tmp_loss.detach().cpu().numpy()

    return loss_train, auc_train, auc_val, auc_test

89

90
91
92
def main(args):
    """Run the repeated train/evaluate experiment and save summary statistics.

    For each of ``args.repeat_num`` repeats a fresh dataset, anchor
    pre-selection, ``PGNN`` model and Adam optimizer are created, the model
    is trained for ``args.epoch_num`` epochs, and the test AUC observed at
    the epoch with the best validation AUC is recorded. The mean and
    standard deviation of those values are printed and written to a text
    file in the ``results`` folder.

    Args:
        args: Parsed command-line namespace. This function reads
            ``inductive``, ``task``, ``k_hop_dist``, ``epoch_num``,
            ``repeat_num`` and ``epoch_log``; the whole namespace is also
            forwarded to ``get_dataset`` and ``preselect_anchor``.
    """
    # The mean and standard deviation of the experiment results
    # are stored in the 'results' folder.
    # `exist_ok=True` replaces the separate isdir() check, so there is no
    # window in which another process can create the folder in between.
    os.makedirs("results", exist_ok=True)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # `args.inductive` is a bool and is used as a 0/1 index; the label is
    # needed both for the banner and for the output file name.
    learning_type = ["Transductive", "Inductive"][args.inductive]
    print(
        "Learning Type: {}".format(learning_type),
        "Task: {}".format(args.task),
    )

    results = []

    for repeat in range(args.repeat_num):
        data = get_dataset(args)

        # pre-sample anchor nodes and compute shortest distance values for
        # all epochs (each returned list is indexed by epoch below)
        (
            g_list,
            anchor_eid_list,
            dist_max_list,
            edge_weight_list,
        ) = preselect_anchor(data, args)

        # model
        model = PGNN(input_dim=data["feature"].shape[1]).to(device)

        # optimizer and loss
        optimizer = torch.optim.Adam(
            model.parameters(), lr=1e-2, weight_decay=5e-4
        )
        loss_func = nn.BCEWithLogitsLoss()

        best_auc_val = -1
        best_auc_test = -1

        for epoch in range(args.epoch_num):
            # One-off learning-rate decay: divide by 10 at epoch 200.
            if epoch == 200:
                for param_group in optimizer.param_groups:
                    param_group["lr"] /= 10

            # Build this epoch's graph from the pre-selected structure.
            g = dgl.graph(g_list[epoch])
            g.ndata["feat"] = torch.FloatTensor(data["feature"])
            g.edata["sp_dist"] = torch.FloatTensor(edge_weight_list[epoch])
            g_data = {
                "graph": g.to(device),
                "anchor_eid": anchor_eid_list[epoch],
                "dists_max": dist_max_list[epoch],
            }

            train_model(data, model, loss_func, optimizer, device, g_data)

            loss_train, auc_train, auc_val, auc_test = eval_model(
                data, g_data, model, loss_func, device
            )

            # Model selection: keep the test AUC from the epoch with the
            # best validation AUC seen so far.
            if auc_val > best_auc_val:
                best_auc_val = auc_val
                best_auc_test = auc_test

            if epoch % args.epoch_log == 0:
                print(
                    repeat,
                    epoch,
                    "Loss {:.4f}".format(loss_train),
                    "Train AUC: {:.4f}".format(auc_train),
                    "Val AUC: {:.4f}".format(auc_val),
                    "Test AUC: {:.4f}".format(auc_test),
                    "Best Val AUC: {:.4f}".format(best_auc_val),
                    "Best Test AUC: {:.4f}".format(best_auc_test),
                )

        results.append(best_auc_test)

    results = np.array(results)
    results_mean = np.mean(results).round(6)
    results_std = np.std(results).round(6)
    print("-----------------Final-------------------")
    print(results_mean, results_std)

    with open(
        "results/{}_{}_{}.txt".format(
            learning_type,
            args.task,
            args.k_hop_dist,
        ),
        "w",
    ) as f:
        f.write("{}, {}\n".format(results_mean, results_std))

186

187
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line options, registered in the order listed here:
    # (flag, keyword arguments for ``add_argument``).
    option_specs = (
        (
            "--task",
            dict(type=str, default="link", choices=["link", "link_pair"]),
        ),
        (
            "--inductive",
            dict(
                action="store_true",
                help="Inductive learning or transductive learning",
            ),
        ),
        (
            "--k_hop_dist",
            dict(
                default=-1,
                type=int,
                help=(
                    "K-hop shortest path distance, "
                    "-1 means exact shortest path."
                ),
            ),
        ),
        ("--epoch_num", dict(type=int, default=2000)),
        ("--repeat_num", dict(type=int, default=10)),
        ("--epoch_log", dict(type=int, default=100)),
    )

    parser = ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    # Parse sys.argv and run the experiment.
    args = parser.parse_args()
    main(args)