sampler.py 1.58 KB
Newer Older
1
2
3
4
5
import os

import torch
from partition_utils import *

6

7
class ClusterIter(object):
    """Partition sampler over a DGLGraph split into a given number of parts.

    Partitioning is delegated to ``get_partition_list``; according to the
    original documentation METIS is the graph partition backend.

    Indexing an instance returns one partition; ``len()`` returns the
    requested partition count.
    """

    def __init__(self, dn, g, psize, batch_size):
        """Initialize the sampler.

        Parameters
        ----------
        dn : str
            The dataset name. When truthy, partitions are cached on disk at
            ``./datasets/<dn>_<psize>.npy`` and reused on later runs.
        g : DGLGraph
            The full graph of the dataset.
        psize : int
            The partition number.
        batch_size : int
            The number of partitions in one batch. Only stored here; it is
            not used by this class itself.
        """
        self.psize = psize
        self.batch_size = batch_size
        # Cache the partitions of known datasets & partition numbers.
        if dn:
            fn = os.path.join("./datasets/", dn + "_{}.npy".format(psize))
            if os.path.exists(fn):
                # NOTE(review): allow_pickle=True unpickles the file content,
                # so the cache directory must be trusted.
                self.par_li = np.load(fn, allow_pickle=True)
            else:
                os.makedirs("./datasets/", exist_ok=True)
                self.par_li = get_partition_list(g, psize)
                # Partitions may differ in size. Recent NumPy refuses to
                # build an array from such a ragged sequence implicitly, so
                # store them in an explicit 1-D object array, one partition
                # per slot.
                cache = np.empty(len(self.par_li), dtype=object)
                for slot, part in enumerate(self.par_li):
                    cache[slot] = part
                np.save(fn, cache)
        else:
            self.par_li = get_partition_list(g, psize)
        # NOTE(review): torch.Tensor(...) yields the default float dtype, so
        # these copies hold the partition entries as floats - confirm this is
        # what consumers of ``par_list`` expect.
        self.par_list = [torch.Tensor(p) for p in self.par_li]

    def __len__(self):
        """Return the requested number of partitions (``psize``)."""
        return self.psize

    def __getitem__(self, idx):
        """Return the partition stored at position ``idx`` of ``par_li``."""
        return self.par_li[idx]

51

52
53
54
55
56
57
def subgraph_collate_fn(g, batch):
    """Build the subgraph of ``g`` covered by a batch of partitions.

    Parameters
    ----------
    g : DGLGraph
        The graph to extract from; must provide ``subgraph``.
    batch : sequence of array-like
        The partitions in this batch. They are concatenated, flattened and
        cast to int64 before being used as node ids.

    Returns
    -------
    The extracted subgraph after its self-loops were removed and then added
    back (presumably so each node ends up with exactly one self-loop -
    confirm against the DGL documentation).
    """
    node_ids = np.concatenate(batch).ravel().astype(np.int64)
    induced = g.subgraph(node_ids)
    # Strip existing self-loops first, then re-add them.
    return dgl.add_self_loop(dgl.remove_self_loop(induced))