dataloader.py 3.27 KB
Newer Older
1
2
3
4
5
"""
MxNet compatible dataloader
"""

import math
6

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
7
8
import dgl

9
10
import numpy as np
from mxnet import nd
11
from mxnet.gluon.data import DataLoader, Sampler
12
from sklearn.model_selection import StratifiedKFold
13
14


15
16
17
18
19
class SubsetRandomSampler(Sampler):
    """Sampler that walks a fixed set of indices in a random order.

    Every call to ``__iter__`` draws a new permutation from NumPy's global
    random state, so the visiting order differs between passes unless the
    global seed is reset.
    """

    def __init__(self, indices):
        # Kept as given; elements are looked up by position when iterating.
        self.indices = indices

    def __len__(self):
        """Return how many indices this sampler covers."""
        return len(self.indices)

    def __iter__(self):
        """Return an iterator over the stored indices in shuffled order."""
        order = np.random.permutation(len(self.indices))
        shuffled = [self.indices[position] for position in order]
        return iter(shuffled)

27

28
29
30
31
32
33
34
35
36
37
# default collate function
def collate(samples):
    """Merge a list of ``(graph, label)`` pairs into a single batch.

    Each graph's node features are converted to MXNet arrays in place, the
    graphs are combined with ``dgl.batch``, and every label is reshaped to
    shape ``(1,)`` before all labels are concatenated along axis 0.

    Returns a ``(batched_graph, labels)`` tuple.
    """
    # Transpose the list of pairs into one list of graphs and one of labels.
    graphs, labels = (list(column) for column in zip(*samples))

    # Only node features are converted; edge features are not handled here.
    for graph in graphs:
        for feat_name in graph.node_attr_schemes().keys():
            graph.ndata[feat_name] = nd.array(graph.ndata[feat_name])

    batched_graph = dgl.batch(graphs)
    labels = nd.concat(*[nd.reshape(item, (1,)) for item in labels], dim=0)
    return batched_graph, labels

42
43
44
45
46
47
48
49
50
51
52
53
54

class GraphDataLoader:
    """Build training and validation ``DataLoader`` objects for a graph dataset.

    The dataset is iterated once to collect its labels, the sample indices
    are split into a training set and a validation set, and each set is
    wrapped in a ``SubsetRandomSampler`` that feeds its own ``DataLoader``.

    Parameters
    ----------
    dataset
        Iterable of ``(graph, label)`` pairs; also handed to ``DataLoader``.
    batch_size
        Batch size used by both loaders.
    collate_fn
        Passed to ``DataLoader`` as ``batchify_fn``.
    seed
        Seed used by the chosen split strategy.
    shuffle
        Whether the stratified 10-fold split shuffles before splitting.
    split_name
        ``"fold10"`` for stratified 10-fold, ``"rand"`` for a ratio split.
    fold_idx
        Which of the 10 folds is used for validation (``"fold10"`` only).
    split_ratio
        Fraction of samples assigned to training (``"rand"`` only).

    Raises
    ------
    NotImplementedError
        If ``split_name`` is neither ``"fold10"`` nor ``"rand"``.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        collate_fn=collate,
        seed=0,
        shuffle=True,
        split_name="fold10",
        fold_idx=0,
        split_ratio=0.7,
    ):
        self.shuffle = shuffle
        self.seed = seed

        labels = [label for _, label in dataset]

        if split_name == "fold10":
            train_idx, valid_idx = self._split_fold10(
                labels, fold_idx, seed, shuffle
            )
        elif split_name == "rand":
            train_idx, valid_idx = self._split_rand(
                labels, split_ratio, seed, shuffle
            )
        else:
            raise NotImplementedError(
                "unsupported split_name: %r" % (split_name,)
            )

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        self.train_loader = DataLoader(
            dataset,
            sampler=train_sampler,
            batch_size=batch_size,
            batchify_fn=collate_fn,
        )
        self.valid_loader = DataLoader(
            dataset,
            sampler=valid_sampler,
            batch_size=batch_size,
            batchify_fn=collate_fn,
        )

    def train_valid_loader(self):
        """Return the ``(train_loader, valid_loader)`` pair."""
        return self.train_loader, self.valid_loader

    def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True):
        """Split sample indices with stratified 10-fold cross validation.

        ``labels`` must contain objects exposing ``asnumpy()``; they are used
        as the stratification targets. Returns the ``(train_idx, valid_idx)``
        pair of fold number ``fold_idx``.
        """
        # The message must be a string: the former ``print(...)`` call
        # evaluated to None and left the AssertionError without any text.
        assert 0 <= fold_idx < 10, "fold_idx must be from 0 to 9."

        # Newer scikit-learn releases reject a random_state when shuffle is
        # False, so the seed is only forwarded when it has an effect.
        skf = StratifiedKFold(
            n_splits=10,
            shuffle=shuffle,
            random_state=seed if shuffle else None,
        )
        # split(X, y): X is a placeholder, only its length matters here.
        idx_list = list(
            skf.split(
                np.zeros(len(labels)), [label.asnumpy() for label in labels]
            )
        )
        train_idx, valid_idx = idx_list[fold_idx]

        print(
            "train_set : test_set = %d : %d" % (len(train_idx), len(valid_idx))
        )

        return train_idx, valid_idx

    def _split_rand(self, labels, split_ratio=0.7, seed=0, shuffle=True):
        """Split sample indices randomly by ``split_ratio``.

        The first ``floor(split_ratio * len(labels))`` shuffled indices form
        the training set and the remainder the validation set.

        NOTE(review): ``shuffle`` is accepted for signature parity with
        ``_split_fold10`` but is not consulted; the indices are always
        shuffled. Confirm whether that is intended before changing it.
        """
        num_entries = len(labels)
        indices = list(range(num_entries))
        # Seeds NumPy's global RNG, which the samplers also draw from.
        np.random.seed(seed)
        np.random.shuffle(indices)
        split = int(math.floor(split_ratio * num_entries))
        train_idx, valid_idx = indices[:split], indices[split:]

        print(
            "train_set : test_set = %d : %d" % (len(train_idx), len(valid_idx))
        )

        return train_idx, valid_idx