model.py 6.38 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
import dgl.function as fn
2
import numpy as np
3
import torch as th
4
import torch.nn as nn
5

6
7
8
9

class CAREConv(nn.Module):
    """One layer of CARE-GNN.

    For every relation (edge type) the layer scores each edge with the L1
    distance between the ``tanh(MLP(.))`` embeddings of its two endpoints,
    keeps the closest fraction ``p`` of every node's incoming edges,
    mean-aggregates the kept neighbours and finally sums the per-relation
    results weighted by ``p``.

    Parameters
    ----------
    in_dim : int
        Size of the input node features.
    out_dim : int
        Size of the features returned by ``forward``.
    num_classes : int
        Output size of the similarity ``MLP``.
    edges : iterable
        Keys used for the per-relation state (``p``, ``f``, ``cvg``,
        ``last_avg_dist``). ``forward`` looks this state up with the entries
        of ``g.canonical_etypes``, so the keys have to match those entries.
    activation : callable, optional
        Applied to every per-relation embedding and to the combined
        embedding when not ``None``.
    step_size : float, optional
        Stored on the layer; no method of this class reads it.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        num_classes,
        edges,
        activation=None,
        step_size=0.02,
    ):
        super().__init__()

        self.activation = activation
        self.step_size = step_size
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_classes = num_classes
        self.edges = edges
        # Latest per-edge distances of every relation, refreshed by forward().
        self.dist = {}

        self.linear = nn.Linear(self.in_dim, self.out_dim)
        self.MLP = nn.Linear(self.in_dim, self.num_classes)

        # Per-relation state:
        #   p             -- fraction of incoming edges kept per node
        #   last_avg_dist -- previous average distance
        #   f             -- history of +1 / -1 threshold actions
        #   cvg           -- whether the threshold is frozen
        self.p = {}
        self.last_avg_dist = {}
        self.f = {}
        self.cvg = {}
        for etype in edges:
            self.p[etype] = 0.5
            self.last_avg_dist[etype] = 0
            self.f[etype] = []
            self.cvg[etype] = False

    def _calc_distance(self, edges):
        """Edge UDF: L1 distance between endpoint similarity embeddings.

        Implements formula 2. Returns ``{"d": d}`` with one distance per edge
        (the norm is reduced over ``dim=1``).
        """
        src_sim = th.tanh(self.MLP(edges.src["h"]))
        dst_sim = th.tanh(self.MLP(edges.dst["h"]))
        return {"d": th.norm(src_sim - dst_sim, p=1, dim=1)}

    def _top_p_sampling(self, g, p):
        """Return the ids of the edges kept by top-p neighbour filtering.

        For every node of ``g`` the ``ceil(in_degree * p)`` incoming edges
        with the smallest ``g.edata["d"]`` are kept.

        This still loops over the nodes in Python and is therefore slow; a
        proper optimisation requires ``dgl.sampling.select_top_p``
        (requested in issue #3100).
        """
        dist = g.edata["d"].detach()
        # Loop-invariant: how many in-edges each node keeps, computed once
        # for the whole graph instead of once per node.
        keep_counts = th.ceil(g.in_degrees() * p).int().tolist()

        neigh_list = []
        for node, num_neigh in zip(g.nodes(), keep_counts):
            edges = g.in_edges(node, form="eid")
            if edges.shape[0] > num_neigh:
                # Hand NumPy a real ndarray rather than relying on the
                # implicit tensor -> array conversion.
                neigh_dist = dist[edges].cpu().numpy()
                keep = np.argpartition(neigh_dist, num_neigh)[:num_neigh]
                edges = edges[th.from_numpy(keep).long().to(edges.device)]
            # Otherwise every incoming edge is kept.
            neigh_list.append(edges)
        return th.cat(neigh_list)

    def forward(self, g, feat):
        """Run one CARE-GNN layer.

        Parameters
        ----------
        g : graph
            Graph whose ``canonical_etypes`` are keys of ``self.p``.
        feat : Tensor
            Node features, stored as ``g.ndata["h"]`` for the duration of
            the call. Presumably of shape ``(num_nodes, in_dim)`` -- TODO
            confirm against the caller.

        Returns
        -------
        Tensor
            ``self.linear`` applied to the combined embedding.
        """
        with g.local_scope():
            g.ndata["h"] = feat

            hr = {}
            for etype in g.canonical_etypes:
                g.apply_edges(self._calc_distance, etype=etype)
                # Cached only for threshold bookkeeping outside this layer,
                # so do not keep the autograd graph attached to it.
                self.dist[etype] = g.edges[etype].data["d"].detach()
                sampled_edges = self._top_p_sampling(g[etype], self.p[etype])

                # formula 8
                out_key = "h_%s" % etype[1]
                g.send_and_recv(
                    sampled_edges,
                    fn.copy_u("h", "m"),
                    fn.mean("m", out_key),
                    etype=etype,
                )
                hr[etype] = g.ndata[out_key]
                if self.activation is not None:
                    hr[etype] = self.activation(hr[etype])

            # formula 9 using mean as inter-relation aggregator
            stacked = th.stack(list(hr.values()))
            # Look the weights up per stacked relation so they stay aligned
            # with ``hr`` even if ``self.p`` was keyed in a different order.
            p_tensor = th.tensor(
                [self.p[etype] for etype in hr],
                dtype=stacked.dtype,
                device=stacked.device,
            ).view(-1, 1, 1)
            h_homo = th.sum(stacked * p_tensor, dim=0)
            h_homo += feat
            if self.activation is not None:
                h_homo = self.activation(h_homo)

            return self.linear(h_homo)


class CAREGNN(nn.Module):
    """Stack of ``CAREConv`` layers with per-relation threshold updates.

    Parameters
    ----------
    in_dim : int
        Size of the input node features.
    num_classes : int
        Output size of the last layer and of every layer's similarity MLP.
    hid_dim : int, optional
        Output size of every layer except the last.
    edges : iterable
        Relation keys handed to every ``CAREConv`` and iterated by
        ``RLModule``. Required despite the ``None`` default.
    num_layers : int, optional
        Number of ``CAREConv`` layers.
    activation : callable, optional
        Forwarded to every ``CAREConv``.
    step_size : float, optional
        Amount by which ``RLModule`` moves a threshold ``p`` per call.
    """

    def __init__(
        self,
        in_dim,
        num_classes,
        hid_dim=64,
        edges=None,
        num_layers=2,
        activation=None,
        step_size=0.02,
    ):
        super().__init__()
        if edges is None:
            # Same exception type the layer constructor would raise when
            # iterating ``None``, but with an actionable message.
            raise TypeError(
                "CAREGNN requires `edges`: an iterable of relation keys "
                "(one entry per edge type of the graph)"
            )

        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.num_classes = num_classes
        self.edges = edges
        self.activation = activation
        self.step_size = step_size
        self.num_layers = num_layers

        # Feature sizes at the layer boundaries: dims[i] -> dims[i + 1].
        if self.num_layers == 1:
            dims = [self.in_dim, self.num_classes]
        else:
            # Input layer, ``num_layers - 2`` hidden layers, output layer.
            # Any value other than 1 yields at least the input and output
            # layers, so num_layers < 2 still builds two layers.
            num_hidden = max(self.num_layers - 2, 0)
            dims = (
                [self.in_dim]
                + [self.hid_dim] * (num_hidden + 1)
                + [self.num_classes]
            )

        self.layers = nn.ModuleList()
        for layer_in, layer_out in zip(dims[:-1], dims[1:]):
            self.layers.append(
                CAREConv(
                    layer_in,
                    layer_out,
                    self.num_classes,
                    self.edges,
                    activation=self.activation,
                    step_size=self.step_size,
                )
            )

    def forward(self, graph, feat):
        """Return ``(feat, sim)`` for full-graph training.

        ``sim`` is ``tanh`` of the first layer's ``MLP`` applied to the input
        features (formula 4); ``feat`` is the output of passing the input
        features through every layer in order.
        """
        # For full graph training, directly use the graph
        # formula 4
        sim = th.tanh(self.layers[0].MLP(feat))

        # Forward of n layers of CARE-GNN
        for layer in self.layers:
            feat = layer(graph, feat)

        return feat, sim

    def RLModule(self, graph, epoch, idx):
        """Update every layer's per-relation filtering threshold ``p``.

        Each layer must already hold distances in ``layer.dist`` for every
        relation in ``self.edges`` (``CAREConv.forward`` fills them);
        otherwise a ``KeyError`` is raised.

        Parameters
        ----------
        graph : graph
            Queried with ``graph.in_edges(idx, form="eid", etype=etype)``.
        epoch : int
            Current epoch; convergence is only tested once ``epoch >= 9``.
        idx : node ids
            Nodes whose incoming edges are averaged.
        """
        for layer in self.layers:
            for etype in self.edges:
                if layer.cvg[etype]:
                    continue

                # formula 5
                eid = graph.in_edges(idx, form="eid", etype=etype)
                if eid.shape[0] == 0:
                    # No edges to average: the mean would be NaN and would
                    # poison last_avg_dist, so skip this relation.
                    continue
                # Plain float so no autograd graph is kept alive in the
                # bookkeeping state between epochs.
                avg_dist = th.mean(layer.dist[etype][eid]).item()

                # formula 6
                if layer.last_avg_dist[etype] < avg_dist:
                    if layer.p[etype] - self.step_size > 0:
                        layer.p[etype] -= self.step_size
                    layer.f[etype].append(-1)
                else:
                    if layer.p[etype] + self.step_size <= 1:
                        layer.p[etype] += self.step_size
                    layer.f[etype].append(+1)
                layer.last_avg_dist[etype] = avg_dist

                # formula 7
                if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:
                    layer.cvg[etype] = True