import argparse
from scipy.sparse.csgraph import shortest_path
import numpy as np
import pandas as pd
import torch
import dgl
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator


def parse_arguments(argv=None):
    """Parse command-line arguments for the SEAL script.

    Args:
        argv(list[str] | None): argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour of the previous zero-argument version.

    Returns:
        args(argparse.Namespace): parsed arguments
    """
    parser = argparse.ArgumentParser(description='SEAL')
    parser.add_argument('--dataset', type=str, default='ogbl-collab')
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--hop', type=int, default=1)
    parser.add_argument('--model', type=str, default='dgcnn')
    parser.add_argument('--gcn_type', type=str, default='gcn')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_units', type=int, default=32)
    parser.add_argument('--sort_k', type=int, default=30)
    parser.add_argument('--pooling', type=str, default='sum')
    # Previously ``type=str``: a value given on the command line was kept as
    # a string (e.g. '0.3') while the default was a float. Parse as float.
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--hits_k', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--neg_samples', type=int, default=1)
    parser.add_argument('--subsample_ratio', type=float, default=0.1)
    parser.add_argument('--epochs', type=int, default=60)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--num_workers', type=int, default=32)
    parser.add_argument('--random_seed', type=int, default=2021)
    parser.add_argument('--save_dir', type=str, default='./processed')
    args = parser.parse_args(argv)
    return args


def load_ogb_dataset(dataset):
    """Load an OGB link-prediction dataset.

    Args:
        dataset(str): name of the dataset (e.g. ogbl-collab, ogbl-ddi,
            ogbl-citation2)

    Returns:
        graph(DGLGraph): graph
        split_edge(dict): split edge, as returned by ``get_edge_split()``
    """
    ogb_dataset = DglLinkPropPredDataset(name=dataset)
    split_edge = ogb_dataset.get_edge_split()
    graph = ogb_dataset[0]
    return graph, split_edge


def drnl_node_labeling(subgraph, src, dst):
    """Double Radius Node Labeling (DRNL).

    For each node i, let r(i, u) be its shortest-path distance to target u
    computed with the other target v removed (and symmetrically for
    r(i, v)), and d = r(i, u) + r(i, v). Then

        label = 1 + min(r(i, u), r(i, v)) + (d // 2) * (d // 2 + d % 2 - 1)

    Both target nodes receive label 1. A node that cannot reach u (with v
    removed) or cannot reach v (with u removed) has an infinite distance.
    The formula then produces NaN, and such nodes are labelled 0.

    The adjacency matrix is materialised densely, so very large subgraphs
    may cause a memory error.

    Args:
        subgraph(DGLGraph): The graph
        src(int): node id of one target node in the subgraph
        dst(int): node id of the other target node in the subgraph
            (assumed different from ``src``)

    Returns:
        z(Tensor): int64 tensor with one label per node of ``subgraph``
    """
    adj = subgraph.adj().to_dense().numpy()
    # Order the endpoints so that src < dst. Removing src then shifts dst's
    # index down by exactly one, while removing dst leaves src's index as is.
    src, dst = (dst, src) if src > dst else (src, dst)

    # Drop the matching row and column to obtain the graph without a node.
    adj_wo_src = np.delete(np.delete(adj, src, axis=0), src, axis=1)
    adj_wo_dst = np.delete(np.delete(adj, dst, axis=0), dst, axis=1)

    # Distances from src in the graph without dst (src keeps its index).
    dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                             indices=src)
    # Re-insert a slot for the removed dst so indices match ``subgraph``.
    dist2src = torch.from_numpy(np.insert(dist2src, dst, 0, axis=0))

    # Distances from dst in the graph without src; dst is now at dst - 1.
    dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                             indices=dst - 1)
    dist2dst = torch.from_numpy(np.insert(dist2dst, src, 0, axis=0))

    dist = dist2src + dist2dst
    dist_over_2, dist_mod_2 = dist // 2, dist % 2

    z = 1 + torch.min(dist2src, dist2dst)
    z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
    z[src] = 1.
    z[dst] = 1.
    # Infinite distances (unreachable nodes) turn into NaN above; label them 0.
    z[torch.isnan(z)] = 0.

    return z.to(torch.long)


def evaluate_hits(name, pos_pred, neg_pred, K):
    """Compute the Hits@K metric with the OGB evaluator.

    Args:
        name(str): name of dataset
        pos_pred(Tensor): predicted scores of positive edges
        neg_pred(Tensor): predicted scores of negative edges
        K(int): the K in Hits@K

    Returns:
        hits(float): Hits@K score
    """
    evaluator = Evaluator(name)
    # The OGB link-prediction evaluator reads K from this attribute and
    # reports the result under the key 'hits@{K}'.
    evaluator.K = K
    hits = evaluator.eval({
        'y_pred_pos': pos_pred,
        'y_pred_neg': neg_pred,
    })[f'hits@{K}']

    return hits