import dgl
import dgl.function as fn
import numpy as np
import torch as th
import torch.nn as nn


def _l1_dist(edges):
    # formula 2: l1-distance between the label-aware scores ("nd") of an
    # edge's two endpoints
    ed = th.norm(edges.src["nd"] - edges.dst["nd"], 1, 1)
    return {"ed": ed}
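

# A minimal sketch (an addition, not part of the original example) of how
# `_l1_dist` is typically driven: write the label-aware scores to
# g.ndata["nd"], then let `apply_edges` materialize the per-edge distances
# that CARESampler consumes (one such dict is expected per layer, since
# CARESampler indexes its `dists` argument by block id first). `mlp` is a
# hypothetical scoring module; any nn.Module mapping node features to class
# scores would do.
def _example_edge_dists(g, mlp):
    with g.local_scope():
        g.ndata["nd"] = th.tanh(mlp(g.ndata["feature"].float()))
        dists = {}
        for etype in g.canonical_etypes:
            g.apply_edges(_l1_dist, etype=etype)
            dists[etype] = g.edges[etype].data["ed"]
        return dists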


class CARESampler(dgl.dataloading.BlockSampler):
    """Similarity-aware top-p neighbor sampler used by CARE-GNN."""

    def __init__(self, p, dists, num_layers):
        super().__init__()
        self.p = p
        self.dists = dists
        self.num_layers = num_layers

    def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
        with g.local_scope():
            new_edges_masks = {}
            for etype in g.canonical_etypes:
                edge_mask = th.zeros(g.number_of_edges(etype))
                # iterate over the seed nodes one at a time; the graph has a
                # single node type, so seed_nodes is a plain tensor of node IDs
                for node in seed_nodes:
                    edges = g.in_edges(node, form="eid", etype=etype)
                    num_neigh = (
                        th.ceil(
                            g.in_degrees(node, etype=etype)
                            * self.p[block_id][etype]
                        )
                        .int()
                        .item()
                    )
                    neigh_dist = self.dists[block_id][etype][edges]
                    if neigh_dist.shape[0] > num_neigh:
                        neigh_index = np.argpartition(neigh_dist, num_neigh)[
                            :num_neigh
                        ]
                    else:
                        # p <= 1 guarantees num_neigh <= in-degree, so every
                        # neighbor is kept in this branch
                        neigh_index = np.arange(num_neigh)
                    edge_mask[edges[neigh_index]] = 1
                new_edges_masks[etype] = edge_mask.bool()

            return dgl.edge_subgraph(g, new_edges_masks, relabel_nodes=False)

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        output_nodes = seed_nodes
        blocks = []
        # build the blocks from the output layer backwards, re-sampling a
        # filtered frontier at every hop
        for block_id in reversed(range(self.num_layers)):
            frontier = self.sample_frontier(block_id, g, seed_nodes)
            eid = frontier.edata[dgl.EID]
            block = dgl.to_block(frontier, seed_nodes)
            block.edata[dgl.EID] = eid
            seed_nodes = block.srcdata[dgl.NID]
            blocks.insert(0, block)

        return seed_nodes, output_nodes, blocks

    def __len__(self):
        return self.num_layers
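

# A minimal usage sketch (an addition, not part of the original example),
# assuming DGL >= 0.8, where dgl.dataloading.DataLoader exists: `p` is a
# list with one {canonical etype: keep-ratio} dict per layer, and `dists`
# is a matching list of {canonical etype: per-edge distance tensor} dicts,
# e.g. one _example_edge_dists(...) result per layer.
def _example_dataloader(g, train_nids, p, dists, num_layers=2):
    sampler = CARESampler(p, dists, num_layers)
    return dgl.dataloading.DataLoader(
        g,
        train_nids,
        sampler,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
    )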


class CAREConv(nn.Module):
    """One layer of CARE-GNN."""

    def __init__(
        self,
        in_dim,
        out_dim,
        num_classes,
        edges,
        activation=None,
        step_size=0.02,
    ):
        super().__init__()

        self.activation = activation
        self.step_size = step_size
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_classes = num_classes
        self.edges = edges

        self.linear = nn.Linear(self.in_dim, self.out_dim)
        self.MLP = nn.Linear(self.in_dim, self.num_classes)

        self.p = {}
        self.last_avg_dist = {}
        self.f = {}
        # indicates whether the RL module has converged, per relation
        self.cvg = {}
        for etype in edges:
            self.p[etype] = 0.5
            self.last_avg_dist[etype] = 0
            self.f[etype] = []
            self.cvg[etype] = False

    def forward(self, g, feat):
        g.srcdata["h"] = feat

        # formula 8: intra-relation aggregation (mean over the sampled neighbors)
        hr = {}
        for etype in g.canonical_etypes:
            g.update_all(fn.copy_u("h", "m"), fn.mean("m", "hr"), etype=etype)
            hr[etype] = g.dstdata["hr"]
            if self.activation is not None:
                hr[etype] = self.activation(hr[etype])

        # formula 9 using mean as inter-relation aggregator
        p_tensor = (
            th.Tensor(list(self.p.values())).view(-1, 1, 1).to(feat.device)
        )
        h_homo = th.sum(th.stack(list(hr.values())) * p_tensor, dim=0)
        # add the self feature of each destination node
        h_homo += feat[: g.number_of_dst_nodes()]
        if self.activation is not None:
            h_homo = self.activation(h_homo)

        return self.linear(h_homo)
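

# A small self-contained shape check (an addition, not part of the original
# example): run one CAREConv layer on a toy two-relation graph. All names
# and sizes below are illustrative.
def _example_careconv():
    g = dgl.heterograph(
        {
            ("u", "r1", "u"): (th.tensor([0, 1, 2]), th.tensor([1, 2, 0])),
            ("u", "r2", "u"): (th.tensor([0, 2]), th.tensor([2, 1])),
        }
    )
    feat = th.randn(g.num_nodes(), 8)
    conv = CAREConv(8, 16, 2, g.canonical_etypes, activation=th.relu)
    out = conv(g, feat)
    assert out.shape == (g.num_nodes(), 16)
    return out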


class CAREGNN(nn.Module):
    """CARE-GNN model composed of stacked CAREConv layers."""

    def __init__(
        self,
        in_dim,
        num_classes,
        hid_dim=64,
        edges=None,
        num_layers=2,
        activation=None,
        step_size=0.02,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.num_classes = num_classes
        self.edges = edges
        self.num_layers = num_layers
        self.activation = activation
        self.step_size = step_size

        self.layers = nn.ModuleList()

        if self.num_layers == 1:
            # Single layer
            self.layers.append(
                CAREConv(
                    self.in_dim,
                    self.num_classes,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

        else:
            # Input layer
            self.layers.append(
                CAREConv(
                    self.in_dim,
                    self.hid_dim,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

            # Hidden layers (num_layers - 2 of them)
            for _ in range(self.num_layers - 2):
                self.layers.append(
                    CAREConv(
                        self.hid_dim,
                        self.hid_dim,
                        self.num_classes,
                        self.edges,
                        activation=self.activation,
                        step_size=self.step_size,
                    )
                )

            # Output layer
            self.layers.append(
                CAREConv(
                    self.hid_dim,
                    self.num_classes,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

    def forward(self, blocks, feat):
        # formula 4: label-aware scores of the output nodes, used for the
        # similarity measure
        sim = th.tanh(self.layers[0].MLP(blocks[-1].dstdata["feature"].float()))

        # forward through the stacked CARE-GNN layers, one block per layer
        for block, layer in zip(blocks, self.layers):
            feat = layer(block, feat)
        return feat, sim

    def RLModule(self, graph, epoch, idx, dists):
        for i, layer in enumerate(self.layers):
            for etype in self.edges:
                if not layer.cvg[etype]:
                    # formula 5
                    eid = graph.in_edges(idx, form="eid", etype=etype)
                    avg_dist = th.mean(dists[i][etype][eid])

                    # formula 6
                    if layer.last_avg_dist[etype] < avg_dist:
                        layer.p[etype] -= self.step_size
                        layer.f[etype].append(-1)
                        # clamp p into (0, 1), following the authors' implementation
                        if layer.p[etype] < 0:
                            layer.p[etype] = 0.001
                    else:
                        layer.p[etype] += self.step_size
                        layer.f[etype].append(+1)
                        if layer.p[etype] > 1:
                            layer.p[etype] = 0.999
                    layer.last_avg_dist[etype] = avg_dist

                    # formula 7: stop updating once the last ten RL actions
                    # roughly cancel out
                    if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:
                        layer.cvg[etype] = True
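

# A condensed training-step sketch (an addition, not part of the original
# example) tying the pieces together. `labels`, `optimizer`, and the loader
# (e.g. built with _example_dataloader above) are assumptions about the
# surrounding training script; `sim` would feed the label-aware similarity
# loss of formula 4 there.
def _example_train_epoch(model, loader, labels, optimizer):
    loss_fn = nn.CrossEntropyLoss()
    for _, output_nodes, blocks in loader:
        feat = blocks[0].srcdata["feature"].float()
        logits, sim = model(blocks, feat)
        loss = loss_fn(logits, labels[output_nodes])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # after every epoch the per-relation thresholds are adjusted via the
    # reinforcement-learning module, e.g.:
    #   model.RLModule(g, epoch, train_nids, dists)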