import os import warnings import numpy as np import torch import torch.nn as nn from model import PGNN from sklearn.metrics import roc_auc_score from utils import get_dataset, preselect_anchor import dgl warnings.filterwarnings("ignore") def get_loss(p, data, out, loss_func, device, get_auc=True): edge_mask = np.concatenate( ( data["positive_edges_{}".format(p)], data["negative_edges_{}".format(p)], ), axis=-1, ) nodes_first = torch.index_select( out, 0, torch.from_numpy(edge_mask[0, :]).long().to(out.device) ) nodes_second = torch.index_select( out, 0, torch.from_numpy(edge_mask[1, :]).long().to(out.device) ) pred = torch.sum(nodes_first * nodes_second, dim=-1) label_positive = torch.ones( [ data["positive_edges_{}".format(p)].shape[1], ], dtype=pred.dtype, ) label_negative = torch.zeros( [ data["negative_edges_{}".format(p)].shape[1], ], dtype=pred.dtype, ) label = torch.cat((label_positive, label_negative)).to(device) loss = loss_func(pred, label) if get_auc: auc = roc_auc_score( label.flatten().cpu().numpy(), torch.sigmoid(pred).flatten().data.cpu().numpy(), ) return loss, auc else: return loss def train_model(data, model, loss_func, optimizer, device, g_data): model.train() out = model(g_data) loss = get_loss("train", data, out, loss_func, device, get_auc=False) optimizer.zero_grad() loss.backward() optimizer.step() optimizer.zero_grad() return g_data def eval_model(data, g_data, model, loss_func, device): model.eval() out = model(g_data) # train loss and auc tmp_loss, auc_train = get_loss("train", data, out, loss_func, device) loss_train = tmp_loss.cpu().data.numpy() # val loss and auc _, auc_val = get_loss("val", data, out, loss_func, device) # test loss and auc _, auc_test = get_loss("test", data, out, loss_func, device) return loss_train, auc_train, auc_val, auc_test def main(args): # The mean and standard deviation of the experiment results # are stored in the 'results' folder if not os.path.isdir("results"): os.mkdir("results") if torch.cuda.is_available(): device = "cuda:0" else: device = "cpu" print( "Learning Type: {}".format( ["Transductive", "Inductive"][args.inductive] ), "Task: {}".format(args.task), ) results = [] for repeat in range(args.repeat_num): data = get_dataset(args) # pre-sample anchor nodes and compute shortest distance values for all epochs ( g_list, anchor_eid_list, dist_max_list, edge_weight_list, ) = preselect_anchor(data, args) # model model = PGNN(input_dim=data["feature"].shape[1]).to(device) # loss optimizer = torch.optim.Adam( model.parameters(), lr=1e-2, weight_decay=5e-4 ) loss_func = nn.BCEWithLogitsLoss() best_auc_val = -1 best_auc_test = -1 for epoch in range(args.epoch_num): if epoch == 200: for param_group in optimizer.param_groups: param_group["lr"] /= 10 g = dgl.graph(g_list[epoch]) g.ndata["feat"] = torch.FloatTensor(data["feature"]) g.edata["sp_dist"] = torch.FloatTensor(edge_weight_list[epoch]) g_data = { "graph": g.to(device), "anchor_eid": anchor_eid_list[epoch], "dists_max": dist_max_list[epoch], } train_model(data, model, loss_func, optimizer, device, g_data) loss_train, auc_train, auc_val, auc_test = eval_model( data, g_data, model, loss_func, device ) if auc_val > best_auc_val: best_auc_val = auc_val best_auc_test = auc_test if epoch % args.epoch_log == 0: print( repeat, epoch, "Loss {:.4f}".format(loss_train), "Train AUC: {:.4f}".format(auc_train), "Val AUC: {:.4f}".format(auc_val), "Test AUC: {:.4f}".format(auc_test), "Best Val AUC: {:.4f}".format(best_auc_val), "Best Test AUC: {:.4f}".format(best_auc_test), ) results.append(best_auc_test) results = np.array(results) results_mean = np.mean(results).round(6) results_std = np.std(results).round(6) print("-----------------Final-------------------") print(results_mean, results_std) with open( "results/{}_{}_{}.txt".format( ["Transductive", "Inductive"][args.inductive], args.task, args.k_hop_dist, ), "w", ) as f: f.write("{}, {}\n".format(results_mean, results_std)) if __name__ == "__main__": from argparse import ArgumentParser parser = ArgumentParser() parser.add_argument( "--task", type=str, default="link", choices=["link", "link_pair"] ) parser.add_argument( "--inductive", action="store_true", help="Inductive learning or transductive learning", ) parser.add_argument( "--k_hop_dist", default=-1, type=int, help="K-hop shortest path distance, -1 means exact shortest path.", ) parser.add_argument("--epoch_num", type=int, default=2000) parser.add_argument("--repeat_num", type=int, default=10) parser.add_argument("--epoch_log", type=int, default=100) args = parser.parse_args() main(args)