dataset.py 1.41 KB
Newer Older
1
import numpy as np
2
3
import torch

4
5
6
7
8
9
10
import dgl


class UnsupportedDatasetError(ValueError, AssertionError):
    """Raised by ``load_dataset`` when the requested dataset name is unknown.

    Also subclasses ``AssertionError`` because earlier versions signalled this
    case with ``assert 0``; callers catching ``AssertionError`` keep working.
    """


def _mask_to_nid(mask):
    """Return the positions of the non-zero entries of a 1-D node mask as a 1-D int64 tensor."""
    # ``torch.nonzero`` yields shape (N, 1) for a 1-D mask. ``reshape(-1)``
    # flattens it to (N,) for every N, whereas ``squeeze()`` would collapse the
    # N == 1 case into a 0-d scalar. It also keeps the mask's device, unlike the
    # legacy ``torch.LongTensor(...)`` constructor.
    return torch.nonzero(mask).reshape(-1).long()


def load_dataset(name):
    """Load a node-classification graph and its train/validation/test node ids.

    Args:
        name: Dataset name, matched case-insensitively. Supported values are
            ``"amazon"`` (loads OGB ``ogbn-products``), ``"reddit"`` and
            ``"cora"``.

    Returns:
        A tuple ``(g, n_classes, train_nid, val_nid, test_nid)``. ``g`` is the
        DGL graph, ``n_classes`` is the number of label classes, and the three
        ``*_nid`` values are the node ids of each split.

    Raises:
        UnsupportedDatasetError: if ``name`` is not one of the supported
            datasets. It is a subclass of both ``ValueError`` and
            ``AssertionError``.
    """
    key = name.lower()

    if key == "amazon":
        from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset

        ogb_data = DglNodePropPredDataset(name="ogbn-products")
        split_idx = ogb_data.get_idx_split()
        train_nid = split_idx["train"]
        val_nid = split_idx["valid"]
        test_nid = split_idx["test"]
        g, labels = ogb_data[0]
        # NOTE(review): this assumes labels are the contiguous range
        # [min, max]; confirm for ogbn-products.
        n_classes = int(labels.max() - labels.min() + 1)
        # Labels are squeezed before storing; presumably OGB returns them
        # as an (N, 1) column. Verify against the OGB docs.
        g.ndata["label"] = labels.squeeze()
        g.ndata["feat"] = g.ndata["feat"].float()

    elif key in ("reddit", "cora"):
        if key == "reddit":
            from dgl.data import RedditDataset

            data = RedditDataset(self_loop=True)
        else:
            from dgl.data import CitationGraphDataset

            data = CitationGraphDataset("cora")
        g = data[0]
        n_classes = data.num_labels
        # These datasets expose their splits as per-node masks; convert each
        # mask into an explicit 1-D tensor of node ids.
        train_nid = _mask_to_nid(g.ndata["train_mask"])
        val_nid = _mask_to_nid(g.ndata["val_mask"])
        test_nid = _mask_to_nid(g.ndata["test_mask"])

    else:
        # Raise instead of ``assert``: asserts are stripped under ``python -O``,
        # which would let execution reach the return with unbound locals.
        raise UnsupportedDatasetError("Dataset {} is not supported".format(name))

    return g, n_classes, train_nid, val_nid, test_nid