import os

import numpy as np
import torch

from partition_utils import *


class ClusterIter(object):
    """The partition sampler given a DGLGraph and a partition number.
    METIS is used as the graph partitioning backend.
    """

    def __init__(self, dn, g, psize, batch_size, seed_nid):
        """Initialize the sampler.

        Parameters
        ----------
        dn : str
            The dataset name.
        g  : DGLGraph
            The full graph of the dataset.
        psize: int
            The number of partitions.
        batch_size: int
            The number of partitions in one batch.
        seed_nid: np.ndarray
            The training node ids, used to extract the training graph.
        """
        self.psize = psize
        self.batch_size = batch_size
        # cache the partitions for a known dataset & partition-number combination
        if dn:
            fn = os.path.join("./datasets/", dn + "_{}.npy".format(psize))
            if os.path.exists(fn):
                self.par_li = np.load(fn, allow_pickle=True)
            else:
                os.makedirs("./datasets/", exist_ok=True)
                self.par_li = get_partition_list(g, psize)
                np.save(fn, self.par_li)
        else:
            self.par_li = get_partition_list(g, psize)
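        # also keep each partition as a torch.Tensor (float32) in self.par_list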
        par_list = []
        for p in self.par_li:
            par = torch.Tensor(p)
            par_list.append(par)
        self.par_list = par_list

    # one-sided normalization: weight each node by 1 / in-degree (isolated nodes get 0)
    def get_norm(self, g):
        norm = 1.0 / g.in_degrees().float().unsqueeze(1)
        norm[torch.isinf(norm)] = 0
        norm = norm.to(g.ndata["feat"].device)
        return norm

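    # The sampler behaves like a map-style dataset: len() is the number of
    # partitions and indexing returns the node ids of one partition, so it can
    # be wrapped directly by torch.utils.data.DataLoader (see the usage sketch
    # at the end of this file).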
    def __len__(self):
        return self.psize

    def __getitem__(self, idx):
        return self.par_li[idx]


def subgraph_collate_fn(g, batch):
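    """Merge a batch of partition node-id arrays into one node-induced subgraph of g."""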
    nids = np.concatenate(batch).reshape(-1).astype(np.int64)
    g1 = g.subgraph(nids)
    return g1
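

if __name__ == "__main__":
    # Usage sketch, not part of the original module: wiring the sampler into a
    # torch DataLoader so that every step trains on the subgraph induced by a
    # few merged partitions, in the style of Cluster-GCN. The dataset choice
    # ("cora" via dgl.data.CoraGraphDataset), psize=50, and batch_size=5 are
    # placeholder assumptions for illustration.
    from functools import partial

    import dgl
    from torch.utils.data import DataLoader

    g = dgl.data.CoraGraphDataset()[0]
    train_nid = np.nonzero(g.ndata["train_mask"].numpy())[0]

    cluster_iter = ClusterIter("cora", g, psize=50, batch_size=5, seed_nid=train_nid)
    loader = DataLoader(
        cluster_iter,
        batch_size=5,  # number of partitions merged into one training subgraph
        shuffle=True,
        collate_fn=partial(subgraph_collate_fn, g),
    )
    for step, cluster in enumerate(loader):
        # each `cluster` is a node-induced DGLGraph carrying the node features;
        # get_norm() gives per-node 1/in-degree weights for one-sided normalization
        norm = cluster_iter.get_norm(cluster)
        print(step, cluster.num_nodes(), cluster.num_edges(), norm.shape)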