load_graph.py 1.8 KB
Newer Older
1
import dgl
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
2
import torch as th
3
4


5
def load_reddit(self_loop=True):
    """Load the Reddit dataset from ``dgl.data`` as a single graph.

    The node-data fields ``"feat"`` and ``"label"`` are moved to the keys
    ``"features"`` and ``"labels"``; the original keys are removed.

    Parameters
    ----------
    self_loop : bool
        Forwarded unchanged to ``RedditDataset(self_loop=...)``.

    Returns
    -------
    tuple
        ``(graph, num_classes)`` where ``graph`` is ``dataset[0]`` and
        ``num_classes`` is the dataset's ``num_classes`` attribute.
    """
    from dgl.data import RedditDataset

    dataset = RedditDataset(self_loop=self_loop)
    graph = dataset[0]
    # Rename node-data fields to the key names this module's loaders return.
    for old_key, new_key in (("feat", "features"), ("label", "labels")):
        graph.ndata[new_key] = graph.ndata.pop(old_key)
    return graph, dataset.num_classes
14

15
16

def _ids_to_mask(node_ids, num_nodes):
    """Return a bool tensor of length ``num_nodes`` that is True at ``node_ids``."""
    mask = th.zeros((num_nodes,), dtype=th.bool)
    mask[node_ids] = True
    return mask


def load_ogb(name, root="dataset"):
    """Load an OGB node-property-prediction dataset as a single DGL graph.

    The node-data field ``"feat"`` is moved to ``"features"``, column 0 of
    the dataset's label tensor is stored as ``"labels"``, and boolean
    ``"train_mask"``, ``"val_mask"`` and ``"test_mask"`` node fields are built
    from ``get_idx_split()``. Progress messages are printed to stdout.

    Parameters
    ----------
    name : str
        Dataset name passed to ``DglNodePropPredDataset(name=...)``.
    root : str
        Directory passed to ``DglNodePropPredDataset(root=...)``.

    Returns
    -------
    tuple
        ``(graph, num_labels)`` where ``num_labels`` is the number of distinct
        non-NaN label values.
    """
    from ogb.nodeproppred import DglNodePropPredDataset

    print("load", name)
    data = DglNodePropPredDataset(name=name, root=root)
    print("finish loading", name)

    splitted_idx = data.get_idx_split()
    graph, labels = data[0]
    # The label tensor is indexed as 2-D here; keep column 0 as the per-node label.
    labels = labels[:, 0]

    graph.ndata["features"] = graph.ndata.pop("feat")
    graph.ndata["labels"] = labels
    # NaN entries (presumably unlabelled nodes) are excluded before counting.
    # NOTE(review): this counts classes actually observed; a class absent from
    # every labelled node would be missed — confirm this is intended.
    num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))

    # Turn the split's node-ID tensors into per-node boolean masks.
    num_nodes = graph.num_nodes()
    for split_name, mask_key in (
        ("train", "train_mask"),
        ("valid", "val_mask"),
        ("test", "test_mask"),
    ):
        graph.ndata[mask_key] = _ids_to_mask(splitted_idx[split_name], num_nodes)
    print("finish constructing", name)
    return graph, num_labels
48

49

50
51
52
def inductive_split(g):
    """Split a graph into training, validation, and test graphs for inductive models.

    Returns ``(train_g, val_g, test_g)``:

    * ``train_g`` -- ``g.subgraph`` over nodes whose ``"train_mask"`` is set;
    * ``val_g`` -- ``g.subgraph`` over nodes whose ``"train_mask"`` or
      ``"val_mask"`` is set;
    * ``test_g`` -- the input graph ``g`` itself (not a copy).
    """
    in_train = g.ndata["train_mask"]
    train_g = g.subgraph(in_train)
    in_train_or_val = in_train | g.ndata["val_mask"]
    return train_g, g.subgraph(in_train_or_val), g