import os
import random

import dgl.function as fn
import torch

from partition_utils import *


class ClusterIter(object):
    '''The partition sampler given a DGLGraph and a partition number.

    METIS is used as the graph-partition backend. Iterating over this
    object yields one subgraph per step, built from ``batch_size``
    partitions; one full pass covers ``psize // batch_size`` batches
    (any leftover partitions are dropped for that epoch).
    '''

    def __init__(self, dn, g, psize, batch_size, seed_nid, use_pp=True):
        """Initialize the sampler.

        Parameters
        ----------
        dn : str
            The dataset name, used as the on-disk cache key for the
            partition list. Falsy values disable caching.
        g : DGLGraph
            The full graph of the dataset.
        psize : int
            The partition number.
        batch_size : int
            The number of partitions in one batch.
        seed_nid : np.ndarray
            The training node ids, used to extract the training graph.
        use_pp : bool
            Whether to use precompute of AX (aggregated features).
        """
        self.use_pp = use_pp
        # train on the induced training subgraph only
        self.g = g.subgraph(seed_nid)

        # precalc the aggregated features from training graph only
        if use_pp:
            self.precalc(self.g)
            print('precalculating')

        self.psize = psize
        self.batch_size = batch_size
        # cache the partitions of known datasets & partition number
        if dn:
            # NOTE: deliberately NOT named `fn` — that would shadow the
            # module alias `dgl.function as fn` used by precalc().
            cache_path = os.path.join('./datasets/', dn + '_{}.npy'.format(psize))
            if os.path.exists(cache_path):
                # allow_pickle is required: the saved object is a list of
                # node-id arrays, not a plain ndarray
                self.par_li = np.load(cache_path, allow_pickle=True)
            else:
                os.makedirs('./datasets/', exist_ok=True)
                self.par_li = get_partition_list(self.g, psize)
                np.save(cache_path, self.par_li)
        else:
            self.par_li = get_partition_list(self.g, psize)
        # number of batches per epoch; `//` already returns an int
        self.max = psize // batch_size
        random.shuffle(self.par_li)
        self.get_fn = get_subgraph

    def precalc(self, g):
        """Precompute one-hop neighbor aggregation of node features.

        Sums neighbor features, normalizes by in-degree, and concatenates
        the result with the raw features (GraphSAGE-style embedding
        aggregation). Mutates ``g.ndata['feat']`` and ``g.ndata['norm']``
        in place.
        """
        norm = self.get_norm(g)
        g.ndata['norm'] = norm
        features = g.ndata['feat']
        print("features shape, ", features.shape)
        with torch.no_grad():
            # message passing: copy each source's features and sum them
            # into the destination's 'feat' field (overwrites raw feats,
            # which is why `features` was saved above)
            g.update_all(fn.copy_src(src='feat', out='m'),
                         fn.sum(msg='m', out='feat'),
                         None)
            pre_feats = g.ndata['feat'] * norm
            # use graphsage embedding aggregation style
            g.ndata['feat'] = torch.cat([features, pre_feats], dim=1)

    # use one side normalization
    def get_norm(self, g):
        """Return the per-node 1/in-degree normalizer as a column vector.

        Isolated nodes (in-degree 0) get a norm of 0 instead of inf.
        """
        norm = 1. / g.in_degrees().float().unsqueeze(1)
        norm[torch.isinf(norm)] = 0
        # keep the normalizer on the same device as the node features
        norm = norm.to(self.g.ndata['feat'].device)
        return norm

    def __len__(self):
        # batches per epoch
        return self.max

    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n < self.max:
            result = self.get_fn(self.g, self.par_li, self.n,
                                 self.psize, self.batch_size)
            self.n += 1
            return result
        else:
            # reshuffle so the next epoch sees partitions in a new order
            random.shuffle(self.par_li)
            raise StopIteration