ladies_sampler.py 8.16 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# This module is based on the following reference implementation: https://github.com/BarclayII/dgl/blob/ladies/examples/pytorch/ladies/ladies2.py

import dgl
import dgl.function as fn
import torch


def find_indices_in(a, b):
    """Locate the elements of ``a`` inside ``b``.

    a : 1-D tensor of values to look up
    b : 1-D tensor that is searched
    return : for each element of ``a``, a position in ``b``.  When the element
             occurs in ``b`` this is the position of a matching entry.  Elements
             larger than every entry of ``b`` fall back to the position of the
             smallest entry of ``b``; other missing elements map to the
             position of the next larger entry.
    """
    ordered_b, order = torch.sort(b)
    slots = torch.searchsorted(ordered_b, a)
    # searchsorted returns len(b) for values past the end; redirect those to
    # slot 0 so the final gather never indexes out of range.
    past_end = slots >= order.shape[0]
    slots = slots.masked_fill(past_end, 0)
    return order[slots]


def union(*arrays):
    """Return the sorted, de-duplicated concatenation of the given 1-D tensors."""
    merged = torch.cat(arrays)
    return torch.unique(merged)


def normalized_edata(g, weight=None):
    """Compute a per-edge normalization factor from incoming edge weights.

    g : the graph
    weight : name of the edge feature to aggregate; when None, a temporary
             all-ones feature named "W" is used
    return : a tensor with one entry per edge, equal to the reciprocal of the
             summed ``weight`` feature over all edges entering that edge's
             destination node.

    All feature writes happen inside ``g.local_scope()``.
    """

    def _inverse_of_dst_total(edges):
        # "v" holds, per destination node, the summed incoming weight.
        return {"w": 1 / edges.dst["v"]}

    with g.local_scope():
        if weight is None:
            weight = "W"
            g.edata[weight] = torch.ones(g.number_of_edges(), device=g.device)
        message_fn = fn.copy_e(weight, weight)
        reduce_fn = fn.sum(weight, "v")
        g.update_all(message_fn, reduce_fn)
        g.apply_edges(_inverse_of_dst_total)
        return g.edata["w"]


class LadiesSampler(dgl.dataloading.BlockSampler):
    """Layer-wise block sampler that draws a budget of nodes for every layer.

    Layers are processed from the output layer backwards.  For each layer the
    sampler builds the in-subgraph of the current seed nodes, assigns every
    candidate node an unnormalized probability, draws candidates with
    ``torch.multinomial`` and emits a block whose edges carry rescaled weights.

    nodes_per_layer : sequence with the number of nodes to draw for each layer
    importance_sampling : if True, a candidate's probability is the sum of the
                          squared weights of its edges pointing at the seed
                          nodes; otherwise every candidate with at least one
                          such edge gets probability 1 and the rest get 0
    weight : name of the edge feature in ``g.edata`` holding the edge weights
    out_weight : name of the block edge feature that receives the rescaled
                 weights
    replace : forwarded to ``torch.multinomial`` as ``replacement``
    """

    def __init__(
        self,
        nodes_per_layer,
        importance_sampling=True,
        weight="w",
        out_weight="edge_weights",
        replace=False,
    ):
        super().__init__()
        self.nodes_per_layer = nodes_per_layer
        self.importance_sampling = importance_sampling
        self.edge_weight = weight
        self.output_weight = out_weight
        self.replace = replace

    def compute_prob(self, g, seed_nodes, weight, num):
        """
        g : the whole graph
        seed_nodes : the output nodes for the current layer
        weight : the weight of the edges of @g
        num : number of nodes to sample; unused here, part of the signature so
              that subclasses can use it
        return : the unnormalized probability of the candidate nodes, as well as the subgraph
                 containing all the edges from the candidate nodes to the output nodes.
        """
        insg = dgl.in_subgraph(g, seed_nodes)
        insg = dgl.compact_graphs(insg, seed_nodes)
        if self.importance_sampling:
            # On the reversed subgraph, summing edge values onto destinations
            # accumulates each candidate's outgoing squared weights.
            out_frontier = dgl.reverse(insg, copy_edata=True)
            weight = weight[out_frontier.edata[dgl.EID].long()]
            prob = dgl.ops.copy_e_sum(out_frontier, weight**2)
        else:
            # Allocate on the graph's device so the boolean mask built from
            # out_degrees() can index the tensor when the graph is not on CPU.
            prob = torch.ones(insg.num_nodes(), device=insg.device)
            prob[insg.out_degrees() == 0] = 0
        return prob, insg

    def select_neighbors(self, prob, num):
        """
        prob : unnormalized probability of each candidate node
        num : number of nodes to sample; capped at the number of candidates
        return : the indices (into @prob) of the sampled candidate nodes.
        """
        # NOTE(review): torch.multinomial documents that drawing without
        # replacement requires at least as many non-zero entries in prob as
        # samples requested; prob may contain zeros -- confirm this holds.
        sample_count = min(num, prob.shape[0])
        return torch.multinomial(prob, sample_count, replacement=self.replace)

    def generate_block(self, insg, neighbor_nodes_idx, seed_nodes, P_sg, W_sg):
        """
        insg : the subgraph yielded by compute_prob()
        neighbor_nodes_idx : the sampled nodes from the subgraph @insg, yielded by select_neighbors()
        seed_nodes : the output nodes for the current layer; looked up in insg.ndata[dgl.NID]
        P_sg : unnormalized probability of each node being sampled, yielded by compute_prob()
        W_sg : edge weights of @insg
        return : the block.
        """
        # Local IDs (within insg) of the seed nodes.
        seed_nodes_idx = find_indices_in(seed_nodes, insg.ndata[dgl.NID])
        # The block keeps the sampled candidates plus the seed nodes.
        u_nodes = union(neighbor_nodes_idx, seed_nodes_idx)
        sg = insg.subgraph(u_nodes.type(insg.idtype))
        u, _ = sg.edges()
        lu = sg.ndata[dgl.NID][u.long()]
        # Keep only edges whose source was actually sampled; edges leaving a
        # seed node that was not drawn are removed.
        s = find_indices_in(lu, neighbor_nodes_idx)
        eg = dgl.edge_subgraph(
            sg, lu == neighbor_nodes_idx[s], relabel_nodes=False
        )
        # Re-attach the IDs that refer to insg rather than to sg.
        eg.ndata[dgl.NID] = sg.ndata[dgl.NID][: eg.num_nodes()]
        eg.edata[dgl.EID] = sg.edata[dgl.EID][eg.edata[dgl.EID].long()]
        sg = eg
        nids = insg.ndata[dgl.NID][sg.ndata[dgl.NID].long()]
        P = P_sg[u_nodes.long()]
        W = W_sg[sg.edata[dgl.EID].long()]
        # Divide each edge weight by the probability of its source node, then
        # rescale so the weights entering a node sum to that node's in-degree.
        W_tilde = dgl.ops.e_div_u(sg, W, P)
        W_tilde_sum = dgl.ops.copy_e_sum(sg, W_tilde)
        d = sg.in_degrees()
        W_tilde = dgl.ops.e_mul_v(sg, W_tilde, d / W_tilde_sum)

        block = dgl.to_block(sg, seed_nodes_idx.type(sg.idtype))
        block.edata[self.output_weight] = W_tilde
        # correct node ID mapping
        block.srcdata[dgl.NID] = nids[block.srcdata[dgl.NID].long()]
        block.dstdata[dgl.NID] = nids[block.dstdata[dgl.NID].long()]

        # correct edge ID mapping: block -> sg -> insg -> IDs stored on insg
        sg_eids = insg.edata[dgl.EID][sg.edata[dgl.EID].long()]
        block.edata[dgl.EID] = sg_eids[block.edata[dgl.EID].long()]
        return block

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """
        g : the whole graph; must carry the edge feature named by the ``weight`` constructor argument
        seed_nodes : the output nodes of the last layer
        exclude_eids : accepted but not used by this implementation
        return : the input nodes of the first block, the original seed nodes, and the list of
                 blocks ordered from the input layer to the output layer.
        """
        output_nodes = seed_nodes
        blocks = []
        # The edge weights do not change between layers, so fetch them once.
        W = g.edata[self.edge_weight]
        for num_nodes_to_sample in reversed(self.nodes_per_layer):
            prob, insg = self.compute_prob(
                g, seed_nodes, W, num_nodes_to_sample
            )
            neighbor_nodes_idx = self.select_neighbors(
                prob, num_nodes_to_sample
            )
            block = self.generate_block(
                insg,
                neighbor_nodes_idx.type(g.idtype),
                seed_nodes.type(g.idtype),
                prob,
                W[insg.edata[dgl.EID].long()],
            )
            # The inputs of this block are the outputs of the next one built.
            seed_nodes = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return seed_nodes, output_nodes, blocks


class PoissonLadiesSampler(LadiesSampler):
    """Variant of LadiesSampler that keeps each candidate independently.

    The unnormalized probabilities from LadiesSampler.compute_prob() are
    scaled and clipped to [0, 1] so that they sum to approximately the
    requested number of nodes; every candidate is then kept or dropped by its
    own Bernoulli draw, so the number of sampled nodes varies between calls.

    nodes_per_layer : sequence with the target number of nodes for each layer
    importance_sampling : see LadiesSampler
    weight : name of the edge feature in ``g.edata`` holding the edge weights
    out_weight : name of the block edge feature that receives the rescaled
                 weights
    skip : if True, the seed nodes are always kept (inclusion probability 1)
    """

    def __init__(
        self,
        nodes_per_layer,
        importance_sampling=True,
        weight="w",
        out_weight="edge_weights",
        skip=False,
    ):
        super().__init__(
            nodes_per_layer, importance_sampling, weight, out_weight
        )
        # The scaling loop stops once the expected sample size and the target
        # agree to within this ratio.
        self.eps = 0.9999
        self.skip = skip

    def compute_prob(self, g, seed_nodes, weight, num):
        """
        g : the whole graph
        seed_nodes : the output nodes for the current layer
        weight : the weight of the edges
        num : the target number of nodes to sample
        return : the inclusion probability (in [0, 1]) of each candidate node, as well as the
                 subgraph containing all the edges from the candidate nodes to the output nodes.
        """
        prob, insg = super().compute_prob(g, seed_nodes, weight, num)

        one = torch.ones_like(prob)
        # Not more candidates than requested: keep every one of them.
        if prob.shape[0] <= num:
            return one, insg

        # Find a scale c such that sum(min(prob * c, 1)) is close to num.
        c = 1.0
        for _ in range(50):
            S = torch.sum(torch.minimum(prob * c, one).to(torch.float64)).item()
            if S <= 0:
                # No probability mass left to rescale (e.g. every candidate
                # has probability 0); dividing by S would raise
                # ZeroDivisionError.
                break
            if min(S, num) / max(S, num) >= self.eps:
                break
            c *= num / S

        inclusion = torch.minimum(prob * c, one)
        if self.skip:
            # Force the seed nodes to be kept.  Assigning 1 after clipping
            # avoids computing inf * c, which is NaN when c is 0.
            skip_nodes = find_indices_in(seed_nodes, insg.ndata[dgl.NID])
            inclusion[skip_nodes] = 1

        return inclusion, insg

    def select_neighbors(self, prob, num):
        """
        prob : inclusion probability of each candidate node, yielded by compute_prob()
        num : unused; the expected sample size is already encoded in @prob
        return : the indices (into @prob) of the candidate nodes whose Bernoulli draw succeeded.
        """
        neighbor_nodes_idx = torch.arange(prob.shape[0], device=prob.device)[
            torch.bernoulli(prob) == 1
        ]
        return neighbor_nodes_idx