# load_graph.py: helpers for loading node-classification graphs (Reddit, OGB)
# with DGL and splitting them into graphs for inductive training.
import dgl
import torch as th

def load_reddit():
    """Load the Reddit dataset as a DGL graph with node data attached.

    The dataset is constructed with ``self_loop=True``. The returned graph has
    these ``ndata`` fields populated:

    * ``'features'`` -- node features, converted with ``th.Tensor``
    * ``'labels'`` -- node labels, converted with ``th.LongTensor``
    * ``'train_mask'``, ``'val_mask'``, ``'test_mask'`` -- boolean split masks

    Returns:
        A ``(g, num_labels)`` tuple: the populated graph and
        ``data.num_labels`` from the dataset.
    """
    # Function-local import, presumably so this module can be imported
    # without pulling in the dataset loader until it is actually needed.
    from dgl.data import RedditDataset

    data = RedditDataset(self_loop=True)
    features = th.Tensor(data.features)
    labels = th.LongTensor(data.labels)

    g = data.graph
    g.ndata['features'] = features
    g.ndata['labels'] = labels
    g.ndata['train_mask'] = th.BoolTensor(data.train_mask)
    g.ndata['val_mask'] = th.BoolTensor(data.val_mask)
    g.ndata['test_mask'] = th.BoolTensor(data.test_mask)
    return g, data.num_labels

def load_ogb(name):
    """Load an OGB node-property-prediction dataset as a DGL graph.

    Args:
        name: dataset name passed to ``DglNodePropPredDataset``.

    The returned graph has these ``ndata`` fields populated:

    * ``'features'`` -- the dataset's ``'feat'`` tensor (same object)
    * ``'labels'`` -- column 0 of the dataset's label tensor
    * ``'train_mask'``, ``'val_mask'``, ``'test_mask'`` -- boolean masks built
      from ``data.get_idx_split()`` (keys ``'train'``, ``'valid'``, ``'test'``)

    Returns:
        A ``(graph, num_labels)`` tuple: the populated graph and the number of
        distinct non-NaN label values.
    """
    from ogb.nodeproppred import DglNodePropPredDataset

    print('load', name)
    data = DglNodePropPredDataset(name=name)
    print('finish loading', name)

    splitted_idx = data.get_idx_split()
    graph, labels = data[0]
    # The label tensor is indexed as 2-D here; keep only its first column.
    labels = labels[:, 0]

    graph.ndata['features'] = graph.ndata['feat']
    graph.ndata['labels'] = labels
    # Count distinct labels while skipping NaN entries. Presumably some
    # datasets mark unlabeled nodes with NaN; confirm for the datasets in use.
    num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))

    # Turn the official train/valid/test node-ID splits into boolean masks.
    num_nodes = graph.number_of_nodes()
    graph.ndata['train_mask'] = _ogb_split_mask(num_nodes, splitted_idx['train'])
    graph.ndata['val_mask'] = _ogb_split_mask(num_nodes, splitted_idx['valid'])
    graph.ndata['test_mask'] = _ogb_split_mask(num_nodes, splitted_idx['test'])

    print('finish constructing', name)
    return graph, num_labels


def _ogb_split_mask(num_nodes, nids):
    """Return a bool tensor of length ``num_nodes`` that is True exactly at ``nids``."""
    mask = th.zeros((num_nodes,), dtype=th.bool)
    mask[nids] = True
    return mask

def inductive_split(g):
    """Derive training, validation and test graphs from ``g`` for inductive models.

    The partition is driven by ``g.ndata['train_mask']`` and
    ``g.ndata['val_mask']``:

    * training graph   -- subgraph induced by the training nodes only
    * validation graph -- subgraph induced by training plus validation nodes
    * test graph       -- ``g`` itself, returned unchanged

    Returns:
        A ``(train_g, val_g, test_g)`` tuple.
    """
    node_data = g.ndata
    train_nodes = node_data['train_mask']
    seen_nodes = train_nodes | node_data['val_mask']
    return g.subgraph(train_nodes), g.subgraph(seen_nodes), g