sampler.py 14.1 KB
Newer Older
1
import math
lt610's avatar
lt610 committed
2
import os
3
import random
lt610's avatar
lt610 committed
4
import time
5
6
7

import numpy as np
import scipy
lt610's avatar
lt610 committed
8
import torch as th
K's avatar
K committed
9
from torch.utils.data import DataLoader
10

lt610's avatar
lt610 committed
11
import dgl
12
13
import dgl.function as fn
from dgl.sampling import pack_traces, random_walk
lt610's avatar
lt610 committed
14
15
16


# The base class of sampler
class SAINTSampler:
    """
    Description
    -----------
    SAINTSampler implements the sampler described in GraphSAINT. This sampler performs offline sampling in the
    pre-sampling phase, and either fully offline or fully online sampling in the training phase.
    Users can set the param 'online' of the sampler to choose between the two training modes.

    Parameters
    ----------
    node_budget : int
        the expected number of nodes in each subgraph, which is specifically explained in the paper. Actually this
        param specifies the times of sampling nodes from the original graph with replacement. The meaning of edge_budget
        is similar to the node_budget.
    dn : str
        name of dataset; it is part of the cache file names produced by ``__generate_fn__``.
    g : DGLGraph
        the full graph.
    train_nid : list
        ids of training nodes.
    num_workers_sampler : int
        number of processes to sample subgraphs in pre-sampling procedure using torch.dataloader.
    num_subg_sampler : int, optional
        the max number of subgraphs sampled in pre-sampling phase for computing normalization coefficients in the beginning.
        Actually this param is used as ``__len__`` of sampler in pre-sampling phase.
        Please make sure that num_subg_sampler is greater than or equal to batch_size_sampler so that we can sample
        enough subgraphs.
        Defaults: 10000
    batch_size_sampler : int, optional
        the number of subgraphs sampled by each process concurrently in pre-sampling phase.
        Defaults: 200
    online : bool, optional
        If `True`, we employ online sampling in training phase. Otherwise employing offline sampling.
        Defaults: True
    num_subg : int, optional
        the expected number of sampled subgraphs in pre-sampling phase.
        It is actually the 'N' in the original paper. Note that this param is different from the num_subg_sampler.
        This param is just used to control the number of pre-sampled subgraphs.
        Defaults: 50
    full : bool, optional
        True if the number of subgraphs used in the training phase equals to that of pre-sampled subgraphs, or
        ``math.ceil(self.train_g.num_nodes() / self.node_budget)``. This formula takes the result of A divided by B as
        the number of subgraphs used in the training phase, where A is the number of training nodes in the original
        graph, B is the expected number of nodes in each pre-sampled subgraph. Please refer to the paper to check the
        details.
        Defaults: True

    Notes
    -----
    For parallelism of pre-sampling, we utilize `torch.DataLoader` to concurrently speed up sampling.
    The `num_subg_sampler` is the return value of `__len__` in pre-sampling phase. Moreover, the param `batch_size_sampler`
    determines the batch_size of `torch.DataLoader` in internal pre-sampling part. But note that if we want to pass the
    SAINTSampler to `torch.DataLoader` for concurrently sampling subgraphs in training phase, we need to specify
    `batch_size` of `DataLoader`, that is, `batch_size_sampler` is not related to how sampler works in training procedure.

    Pre-sampled subgraphs and the normalization coefficients are saved with ``np.save`` to the paths returned by
    ``__generate_fn__`` and are reloaded from there when the subgraph file already exists.
    """

    def __init__(
        self,
        node_budget,
        dn,
        g,
        train_nid,
        num_workers_sampler,
        num_subg_sampler=10000,
        batch_size_sampler=200,
        online=True,
        num_subg=50,
        full=True,
    ):
        self.g = g.cpu()
        self.node_budget = node_budget
        self.train_g: dgl.graph = g.subgraph(train_nid)
        self.dn, self.num_subg = dn, num_subg
        # Per-node / per-edge occurrence counts over all pre-sampled subgraphs.
        self.node_counter = th.zeros((self.train_g.num_nodes(),))
        self.edge_counter = th.zeros((self.train_g.num_edges(),))
        # Sampling distribution; built lazily by the subclasses' __sample__.
        self.prob = None
        self.num_subg_sampler = num_subg_sampler
        self.batch_size_sampler = batch_size_sampler
        self.num_workers_sampler = num_workers_sampler
        self.train = False
        self.online = online
        self.full = full

        assert (
            self.num_subg_sampler >= self.batch_size_sampler
        ), "num_subg_sampler should be greater than or equal to batch_size_sampler"

        graph_fn, norm_fn = self.__generate_fn__()

        if os.path.exists(graph_fn):
            # list() turns the saved 1-D object array (or a legacy 2-D array written by older
            # versions) into a plain list of per-subgraph node-id arrays, so random.shuffle
            # below permutes whole subgraphs instead of swapping overlapping row views.
            self.subgraphs = list(np.load(graph_fn, allow_pickle=True))
            aggr_norm, loss_norm = np.load(norm_fn, allow_pickle=True)
        else:
            os.makedirs("./subgraphs/", exist_ok=True)

            self.subgraphs = []
            # N: the number of pre-sampled subgraphs
            self.N, sampled_nodes = 0, 0

            # Employ parallelism to speed up the sampling procedure
            loader = DataLoader(
                self,
                batch_size=self.batch_size_sampler,
                shuffle=True,
                num_workers=self.num_workers_sampler,
                collate_fn=self.__collate_fn__,
                drop_last=False,
            )

            t = time.perf_counter()
            for num_nodes, subgraphs_nids, subgraphs_eids in loader:
                self.subgraphs.extend(subgraphs_nids)
                sampled_nodes += num_nodes

                # Count how often each training-graph node occurs in this batch of subgraphs.
                unique_nids, node_counts = np.unique(
                    np.concatenate(subgraphs_nids), return_counts=True
                )
                self.node_counter[th.from_numpy(unique_nids)] += th.from_numpy(
                    node_counts
                )

                # Same for edges; the ids are edge ids of self.train_g (see __getitem__).
                unique_eids, edge_counts = np.unique(
                    np.concatenate(subgraphs_eids), return_counts=True
                )
                self.edge_counter[th.from_numpy(unique_eids)] += th.from_numpy(
                    edge_counts
                )

                self.N += len(subgraphs_nids)  # number of subgraphs
                # Stop once about num_subg "passes" worth of training nodes have been sampled.
                if sampled_nodes > self.train_g.num_nodes() * num_subg:
                    break

            print(f"Sampling time: [{time.perf_counter() - t:.2f}s]")
            # Subgraphs have different sizes: save them as a 1-D object array so NumPy
            # neither rejects the ragged input (NumPy >= 1.24) nor stacks equal-length ones.
            np.save(graph_fn, self._as_object_array(self.subgraphs))

            t = time.perf_counter()
            aggr_norm, loss_norm = self.__compute_norm__()
            print(f"Normalization time: [{time.perf_counter() - t:.2f}s]")
            # aggr_norm has one entry per edge and loss_norm one per node, i.e. a ragged pair.
            np.save(norm_fn, self._as_object_array([aggr_norm, loss_norm]))

        self.train_g.ndata["l_n"] = th.Tensor(loss_norm)
        self.train_g.edata["w"] = th.Tensor(aggr_norm)
        self.__compute_degree_norm()  # basically normalizing adjacent matrix

        random.shuffle(self.subgraphs)
        self.__clear__()
        print("The number of subgraphs is: ", len(self.subgraphs))

        self.train = True

    @staticmethod
    def _as_object_array(items):
        """Pack ``items`` into a 1-D object ndarray, storing each item as one element.

        Assigning element by element prevents NumPy from broadcasting or stacking the items,
        which would fail for ragged arrays or silently yield a 2-D array for equal-length ones.
        """
        packed = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            packed[i] = item
        return packed

    def __len__(self):
        """Return the dataset length for the current phase.

        Pre-sampling: ``num_subg_sampler``. Training: the number of pre-sampled subgraphs when
        ``full`` is set, otherwise ``ceil(#training nodes / node_budget)``.
        """
        if self.train is False:
            return self.num_subg_sampler
        else:
            if self.full:
                return len(self.subgraphs)
            else:
                return math.ceil(self.train_g.num_nodes() / self.node_budget)

    def __getitem__(self, idx):
        """Return one sample for the current phase.

        Training, online: a freshly sampled node-induced subgraph of ``train_g`` (``idx`` is ignored).
        Training, offline: the subgraph induced by the ``idx``-th pre-sampled node set.
        Pre-sampling: ``(num_nodes, node_ids, edge_ids)``, where ``edge_ids`` are the ``train_g``
        edge ids of the induced subgraph.
        """
        if self.train:
            if self.online:
                subgraph = self.__sample__()
                return dgl.node_subgraph(self.train_g, subgraph)
            else:
                return dgl.node_subgraph(self.train_g, self.subgraphs[idx])
        else:
            subgraph_nids = self.__sample__()
            num_nodes = len(subgraph_nids)
            subgraph_eids = dgl.node_subgraph(
                self.train_g, subgraph_nids
            ).edata[dgl.EID]
            return num_nodes, subgraph_nids, subgraph_eids

    def __collate_fn__(self, batch):
        """Collate DataLoader items.

        In training, only the first item is returned (batch size in the training phase is meant
        to be 1). In pre-sampling, returns the total node count plus lists of per-subgraph node
        ids and edge ids.
        """
        if self.train:  # sample only one graph each epoch, batch_size in training phase is 1
            return batch[0]
        else:
            sum_num_nodes = 0
            subgraphs_nids_list = []
            subgraphs_eids_list = []
            for num_nodes, subgraph_nids, subgraph_eids in batch:
                sum_num_nodes += num_nodes
                subgraphs_nids_list.append(subgraph_nids)
                subgraphs_eids_list.append(subgraph_eids)
            return sum_num_nodes, subgraphs_nids_list, subgraphs_eids_list

    def __clear__(self):
        """Drop pre-sampling state (distribution, counters, full-graph reference) after setup."""
        self.prob = None
        self.node_counter = None
        self.edge_counter = None
        self.g = None

    def __generate_fn__(self):
        """Return ``(graph_fn, norm_fn)`` cache paths; implemented by subclasses."""
        raise NotImplementedError

    def __compute_norm__(self):
        """Compute GraphSAINT normalization coefficients from the occurrence counters.

        Returns ``(aggr_norm, loss_norm)`` as NumPy arrays:
        ``loss_norm[v] = N / (count[v] * #train nodes)`` per node, and
        ``aggr_norm[e] = count[dst(e)] / count[e]`` per edge. Zero counts are replaced by 1
        to avoid division by zero.
        """
        self.node_counter[self.node_counter == 0] = 1
        self.edge_counter[self.edge_counter == 0] = 1

        loss_norm = self.N / self.node_counter / self.train_g.num_nodes()

        self.train_g.ndata["n_c"] = self.node_counter
        self.train_g.edata["e_c"] = self.edge_counter
        # v_div_e: destination-node count divided by edge count, stored per edge.
        self.train_g.apply_edges(fn.v_div_e("n_c", "e_c", "a_n"))
        aggr_norm = self.train_g.edata.pop("a_n")

        self.train_g.ndata.pop("n_c")
        self.train_g.edata.pop("e_c")

        return aggr_norm.numpy(), loss_norm.numpy()

    def __compute_degree_norm(self):
        """Store ``1 / in_degree`` (clamped to degree >= 1), as a column, on the train and full graphs."""
        self.train_g.ndata[
            "train_D_norm"
        ] = 1.0 / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1)
        self.g.ndata["full_D_norm"] = 1.0 / self.g.in_degrees().float().clamp(
            min=1
        ).unsqueeze(1)

    def __sample__(self):
        """Return the node ids of one sampled subgraph; implemented by subclasses."""
        raise NotImplementedError


class SAINTNodeSampler(SAINTSampler):
    """
    Description
    -----------
    GraphSAINT with node sampler. Nodes of the training graph are drawn with replacement,
    with weights equal to their in-degree (clamped to at least 1), and the distinct drawn
    nodes form one subgraph.

    Parameters
    ----------
    node_budget : int
        the expected number of nodes in each subgraph, which is specifically explained in the paper.
    """

    def __init__(self, node_budget, **kwargs):
        # Assigned here as well as in the base initializer, mirroring the other samplers,
        # which store their own parameters before delegating.
        self.node_budget = node_budget
        super().__init__(node_budget=node_budget, **kwargs)

    def __generate_fn__(self):
        """Return the ``(subgraphs, normalization)`` cache paths for this configuration."""
        stem = "{}_Node_{}_{}".format(self.dn, self.node_budget, self.num_subg)
        return f"./subgraphs/{stem}.npy", f"./subgraphs/{stem}_norm.npy"

    def __sample__(self):
        """Draw ``node_budget`` nodes with replacement and return the distinct ids as a NumPy array."""
        if self.prob is None:
            # Unnormalized weights; th.multinomial normalizes them internally.
            degrees = self.train_g.in_degrees()
            self.prob = degrees.float().clamp(min=1)

        draws = th.multinomial(
            self.prob, num_samples=self.node_budget, replacement=True
        )
        distinct = draws.unique()
        return distinct.numpy()


class SAINTEdgeSampler(SAINTSampler):
    """
    Description
    -----------
    GraphSAINT with edge sampler. Edges ``(u, v)`` of the upper-triangular adjacency matrix are
    drawn without replacement with probability proportional to ``1 / deg(u) + 1 / deg(v)``
    (in-degrees clamped to at least 1). The distinct endpoints of the drawn edges form one subgraph.

    Parameters
    ----------
    edge_budget : int
        the expected number of edges in each subgraph, which is specifically explained in the paper.
    """

    def __init__(self, edge_budget, **kwargs):
        self.edge_budget = edge_budget
        # Not used by the sampling code in this class; kept so existing attribute access keeps working.
        self.rng = np.random.default_rng()

        super().__init__(node_budget=edge_budget * 2, **kwargs)

    def __generate_fn__(self):
        """Return the ``(subgraphs, normalization)`` cache paths for this configuration."""
        graph_fn = os.path.join(
            "./subgraphs/{}_Edge_{}_{}.npy".format(
                self.dn, self.edge_budget, self.num_subg
            )
        )
        norm_fn = os.path.join(
            "./subgraphs/{}_Edge_{}_{}_norm.npy".format(
                self.dn, self.edge_budget, self.num_subg
            )
        )
        return graph_fn, norm_fn

    def _build_edge_distribution(self):
        """Cache the edge-sampling distribution and the matching edge endpoints.

        Sets ``self.prob`` (normalized weights, one per upper-triangle entry) and
        ``self.adj_nodes`` (``(num_entries, 2)`` array of ``(row, col)``), both taken from the
        same COO matrix so that index ``i`` refers to the same edge in both.
        """
        src, dst = self.train_g.edges()
        src_degrees = self.train_g.in_degrees(src).float().clamp(min=1)
        dst_degrees = self.train_g.in_degrees(dst).float().clamp(min=1)
        weights = 1.0 / src_degrees + 1.0 / dst_degrees
        prob_mat = scipy.sparse.csr_matrix(
            (weights.numpy(), (src.numpy(), dst.numpy()))
        )
        # The edge probability here only contains that of edges in upper triangle adjacency matrix
        # Because we assume the graph is undirected, that is, the adjacency matrix is symmetric. We only need
        # to consider half of edges in the graph.
        # The endpoints MUST come from this same upper-triangle matrix: using the full matrix's
        # nonzero() list would misalign sampled indices with their edges.
        upper = scipy.sparse.triu(prob_mat, format="coo")
        self.prob = th.tensor(upper.data)
        self.prob /= self.prob.sum()
        self.adj_nodes = np.stack([upper.row, upper.col], axis=1)

    # TODO: only sample half edges, then add another half edges
    # TODO: use numpy to implement cython sampling method
    def __sample__(self):
        """Sample up to ``edge_budget`` distinct edges; return their distinct endpoints as int64 ids."""
        if self.prob is None:
            self._build_edge_distribution()

        # Sampling is without replacement, so never ask for more edges than exist.
        num_draws = min(self.edge_budget, len(self.prob))
        sampled_edges = np.unique(
            dgl.random.choice(
                len(self.prob),
                size=num_draws,
                prob=self.prob,
                replace=False,
            )
        )
        sampled_nodes = np.unique(
            self.adj_nodes[sampled_edges].flatten()
        ).astype(np.int64)
        return sampled_nodes


class SAINTRandomWalkSampler(SAINTSampler):
    """
    Description
    -----------
    GraphSAINT with random walk sampler. ``num_roots`` start nodes are drawn uniformly from the
    training graph, a random walk of ``length`` steps is run from each, and the distinct nodes
    of the packed walk traces form one subgraph.

    Parameters
    ----------
    num_roots : int
        the number of roots to generate random walks.
    length : int
        the length of each random walk.
    """

    def __init__(self, num_roots, length, **kwargs):
        self.num_roots = num_roots
        self.length = length
        super().__init__(node_budget=num_roots * length, **kwargs)

    def __generate_fn__(self):
        """Return the ``(subgraphs, normalization)`` cache paths for this configuration."""
        stem = "{}_RW_{}_{}_{}".format(
            self.dn, self.num_roots, self.length, self.num_subg
        )
        return f"./subgraphs/{stem}.npy", f"./subgraphs/{stem}_norm.npy"

    def __sample__(self):
        """Run the random walks and return the distinct visited node ids as a NumPy array."""
        n_nodes = self.train_g.num_nodes()
        roots = th.randint(0, n_nodes, (self.num_roots,))
        traces, types = random_walk(self.train_g, nodes=roots, length=self.length)

        packed_nodes, _, _, _ = pack_traces(traces, types)
        return packed_nodes.unique().numpy()