# dataset.py
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


def load(name):
    """Load a DGL citation-graph dataset and split out its node tensors.

    Args:
        name: Dataset identifier; one of ``"cora"``, ``"citeseer"`` or
            ``"pubmed"``.

    Returns:
        A tuple ``(graph, feat, labels, train_mask, test_mask)``. ``graph`` is
        the first graph of the dataset. The other four values are popped out
        of ``graph.ndata`` (keys ``"feat"``, ``"label"``, ``"train_mask"``,
        ``"test_mask"``), so the returned graph no longer carries them. Any
        other node data is left in place.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    if name == "cora":
        dataset = CoraGraphDataset()
    elif name == "citeseer":
        dataset = CiteseerGraphDataset()
    elif name == "pubmed":
        dataset = PubmedGraphDataset()
    else:
        # Without this branch `dataset` would be unbound below and the caller
        # would get an opaque UnboundLocalError instead of a clear message.
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of "
            "'cora', 'citeseer', 'pubmed'."
        )

    graph = dataset[0]

    # Pop (rather than read) so the tensors are returned separately and are
    # removed from the graph's node data.
    train_mask = graph.ndata.pop("train_mask")
    test_mask = graph.ndata.pop("test_mask")

    feat = graph.ndata.pop("feat")
    labels = graph.ndata.pop("label")

    return graph, feat, labels, train_mask, test_mask