import argparse

from ogb.linkproppred import *
from ogb.nodeproppred import *

from dgl.data import CitationGraphDataset


def load_graph(name):
    """
    Load a graph dataset by name together with its evaluation splits.

    Parameters
    ----------
    name : str
        Either one of "cora", "citeseer", "pubmed" (loaded through
        ``CitationGraphDataset``) or a name starting with "ogbn" (loaded
        through ``DglNodePropPredDataset``).

    Returns
    -------
    tuple
        ``(graph, eval_set)`` where ``eval_set`` is a list of two
        ``(node_ids, labels)`` pairs: the training split first, then the
        held-out split.

    Raises
    ------
    ValueError
        If ``name`` matches neither dataset family (including non-string
        values).
    """
    cite_graphs = ["cora", "citeseer", "pubmed"]

    if name in cite_graphs:
        dataset = CitationGraphDataset(name)
        graph = dataset[0]

        y = graph.ndata["label"]
        train_sel = graph.ndata["train_mask"]
        # NOTE(review): the held-out split here is read from "test_mask",
        # whereas the OGB branch below uses the "valid" split -- confirm
        # this asymmetry is intended.
        val_sel = graph.ndata["test_mask"]

    # The isinstance guard lets non-string names fall through to the
    # ValueError below instead of failing on ``.startswith``.
    elif isinstance(name, str) and name.startswith("ogbn"):
        dataset = DglNodePropPredDataset(name)
        graph, y = dataset[0]
        split_nodes = dataset.get_idx_split()

        train_sel = split_nodes["train"]
        val_sel = split_nodes["valid"]

    else:
        raise ValueError(
            f"Dataset name error! Got {name!r}; expected one of "
            f"{cite_graphs} or a name starting with 'ogbn'."
        )

    # Both branches yield a selector (mask or index) per split; apply it
    # to the node ids and the labels in one shared place.
    nodes = graph.nodes()
    eval_set = [
        (nodes[train_sel], y[train_sel]),
        (nodes[val_sel], y[val_sel]),
    ]

    return graph, eval_set


def parse_arguments(argv=None):
    """
    Parse command-line arguments for the Node2vec script.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When left as ``None``, argparse reads
        them from ``sys.argv[1:]``, which is the behaviour of calling this
        function with no arguments.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``dataset``, ``task``, ``runs``,
        ``device``, ``embedding_dim``, ``walk_length``, ``p``, ``q``,
        ``num_walks``, ``epochs`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description="Node2vec")
    parser.add_argument("--dataset", type=str, default="cora")
    # 'train' for training node2vec model, 'time' for testing speed of random walk
    parser.add_argument("--task", type=str, default="train")
    parser.add_argument("--runs", type=int, default=10)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--walk_length", type=int, default=50)
    parser.add_argument("--p", type=float, default=0.25)
    parser.add_argument("--q", type=float, default=4.0)
    parser.add_argument("--num_walks", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=128)

    # Passing argv through lets callers (and tests) supply arguments
    # explicitly; None keeps the sys.argv default.
    return parser.parse_args(argv)