"references/vscode:/vscode.git/clone" did not exist on "22a23e8d229daaa241e2a1c3fe123f078cca1645"
dataset.py 1.41 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import dgl
2
import numpy as np
3
4
import torch

5
6
7
8
9

def load_dataset(name):
    """Load a graph dataset together with its train/val/test node ids.

    The lookup is case-insensitive. Supported names:

    * ``"amazon"`` -- loads ``ogbn-products`` through OGB's
      ``DglNodePropPredDataset``; the split indices come from
      ``get_idx_split()``, the squeezed labels are stored in
      ``g.ndata["label"]`` and ``g.ndata["feat"]`` is cast to float.
    * ``"reddit"`` -- ``dgl.data.RedditDataset(self_loop=True)``.
    * ``"cora"`` -- ``dgl.data.CitationGraphDataset("cora")``.

    For ``"reddit"`` and ``"cora"`` the node ids are derived from the
    ``train_mask`` / ``val_mask`` / ``test_mask`` entries of ``g.ndata``.

    Args:
        name: Dataset name (any letter case).

    Returns:
        A tuple ``(g, n_classes, train_nid, val_nid, test_nid)``.

    Raises:
        AssertionError: If ``name`` is not one of the supported datasets.
    """
    key = name.lower()

    if key == "amazon":
        # Imported lazily so ogb is only needed when this dataset is requested.
        from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset

        data = DglNodePropPredDataset(name="ogbn-products")
        splitted_idx = data.get_idx_split()
        train_nid = splitted_idx["train"]
        val_nid = splitted_idx["valid"]
        test_nid = splitted_idx["test"]
        g, labels = data[0]
        # Class count is taken from the span of label values.
        n_classes = int(labels.max() - labels.min() + 1)
        g.ndata["label"] = labels.squeeze()
        g.ndata["feat"] = g.ndata["feat"].float()
    elif key in ("reddit", "cora"):
        if key == "reddit":
            from dgl.data import RedditDataset

            data = RedditDataset(self_loop=True)
        else:
            from dgl.data import CitationGraphDataset

            data = CitationGraphDataset("cora")
        g = data[0]
        n_classes = data.num_labels

        def _mask_to_nid(mask):
            # Squeeze only the coordinate dimension: a bare ``squeeze()``
            # would turn a split with exactly one node into a 0-d tensor.
            return torch.nonzero(mask, as_tuple=False).squeeze(1)

        train_nid = _mask_to_nid(g.ndata["train_mask"])
        val_nid = _mask_to_nid(g.ndata["val_mask"])
        test_nid = _mask_to_nid(g.ndata["test_mask"])
    else:
        message = "Dataset {} is not supported".format(name)
        print(message)
        # Raise explicitly: ``assert 0`` is removed under ``python -O``, which
        # would let execution reach the return with unbound locals.
        raise AssertionError(message)

    return g, n_classes, train_nid, val_nid, test_nid