# utils.py
import argparse

import numpy as np
import pandas as pd
import torch
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path

import dgl


def parse_arguments(argv=None):
    """
    Parse the command-line options of the SEAL script.

    Args:
        argv(list[str] | None): argument strings to parse. When None
            (the default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        args(argparse.Namespace): parsed options, one attribute per flag
    """
    parser = argparse.ArgumentParser(description="SEAL")
    parser.add_argument("--dataset", type=str, default="ogbl-collab")
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--hop", type=int, default=1)
    parser.add_argument("--model", type=str, default="dgcnn")
    parser.add_argument("--gcn_type", type=str, default="gcn")
    parser.add_argument("--num_layers", type=int, default=3)
    parser.add_argument("--hidden_units", type=int, default=32)
    parser.add_argument("--sort_k", type=int, default=30)
    parser.add_argument("--pooling", type=str, default="sum")
    # The dropout rate is numeric: with type=str a value given on the
    # command line would arrive as a string while the default stayed a float.
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--hits_k", type=int, default=50)
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--neg_samples", type=int, default=1)
    parser.add_argument("--subsample_ratio", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--eval_steps", type=int, default=5)
    parser.add_argument("--num_workers", type=int, default=32)
    parser.add_argument("--random_seed", type=int, default=2021)
    parser.add_argument("--save_dir", type=str, default="./processed")

    return parser.parse_args(argv)


def load_ogb_dataset(dataset):
    """
    Load an OGB link-prediction dataset through its DGL wrapper.

    Args:
        dataset(str): name of dataset (ogbl-collab, ogbl-ddi, ogbl-citation)

    Returns:
        graph(DGLGraph): first entry of the loaded dataset
        split_edge(dict): edge split returned by ``get_edge_split()``
    """
    ogb_data = DglLinkPropPredDataset(name=dataset)
    edge_splits = ogb_data.get_edge_split()

    return ogb_data[0], edge_splits


def drnl_node_labeling(subgraph, src, dst):
    """
    Compute Double Radius Node Labeling (DRNL) for every node of a subgraph.

    With r(i,u), r(i,v) the hop distances from node i to the two target
    nodes and d = r(i,u) + r(i,v):
        label = 1 + min(r(i,u), r(i,v)) + (d//2) * (d//2 + d%2 - 1)
    The distance to one target is measured on the graph with the other
    target removed. Both target nodes are labelled 1, and every node whose
    label evaluates to NaN (isolated nodes in the subgraph) is labelled 0.
    The adjacency matrix is converted to a dense array, so extremely large
    graphs may cause a memory error.

    Args:
        subgraph(DGLGraph): The graph
        src(int): node id of one of src node in new subgraph
        dst(int): node id of one of dst node in new subgraph
    Returns:
        z(Tensor): node labeling tensor (dtype torch.long)
    """
    dense_adj = subgraph.adj().to_dense().numpy()
    num_nodes = dense_adj.shape[0]

    def without_node(node):
        # Adjacency matrix with the row and column of `node` dropped.
        keep = [*range(node), *range(node + 1, num_nodes)]
        return dense_adj[keep, :][:, keep]

    # Order the pair so that src <= dst; the index arithmetic below
    # (dst - 1 after removing src) relies on it.
    if src > dst:
        src, dst = dst, src

    # Distances to src on the graph without dst; src keeps its index there
    # because it precedes dst. A 0 placeholder is re-inserted at dst.
    to_src = shortest_path(
        without_node(dst), directed=False, unweighted=True, indices=src
    )
    to_src = torch.from_numpy(np.insert(to_src, dst, 0, axis=0))

    # Distances to dst on the graph without src; dst shifts down by one
    # once src is removed. A 0 placeholder is re-inserted at src.
    to_dst = shortest_path(
        without_node(src), directed=False, unweighted=True, indices=dst - 1
    )
    to_dst = torch.from_numpy(np.insert(to_dst, src, 0, axis=0))

    total = to_src + to_dst
    half = total // 2
    parity = total % 2

    labels = 1 + torch.min(to_src, to_dst)
    labels += half * (half + parity - 1)

    # The two target nodes always get label 1.
    labels[src] = 1.0
    labels[dst] = 1.0
    # NaN labels are mapped to 0.
    labels.masked_fill_(torch.isnan(labels), 0.0)

    return labels.to(torch.long)


def evaluate_hits(name, pos_pred, neg_pred, K):
    """
    Score link predictions with the OGB evaluator's hits@K metric.

    Args:
        name(str): name of dataset
        pos_pred(Tensor): predict value of positive edges
        neg_pred(Tensor): predict value of negative edges
        K(int): num of hits

    Returns:
        hits(float): score of hits
    """
    scorer = Evaluator(name)
    # Override the evaluator's K with the requested cut-off.
    scorer.K = K

    predictions = {
        "y_pred_pos": pos_pred,
        "y_pred_neg": neg_pred,
    }
    metrics = scorer.eval(predictions)

    return metrics[f"hits@{K}"]