import dgl.function as fn
import torch as th
import torch.nn as nn
from dgl.nn.functional import edge_softmax


class MLP(nn.Module):
    """Edge scorer: concatenate edge, source, and destination features and
    apply a single linear layer to produce per-edge logits."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim)

    def apply_edges(self, edges):
        h_e = edges.data["h"]
        h_u = edges.src["h"]
        h_v = edges.dst["h"]
        score = self.W(th.cat([h_e, h_u, h_v], -1))
        return {"score": score}

    def forward(self, g, e_feat, u_feat, v_feat):
        with g.local_scope():
            g.edges["forward"].data["h"] = e_feat
            g.nodes["u"].data["h"] = u_feat
            g.nodes["v"].data["h"] = v_feat
            g.apply_edges(self.apply_edges, etype="forward")
            return g.edges["forward"].data["score"]


class GASConv(nn.Module):
    """One layer of GAS."""

    def __init__(
        self,
        e_in_dim,
        u_in_dim,
        v_in_dim,
        e_out_dim,
        u_out_dim,
        v_out_dim,
        activation=None,
        dropout=0,
    ):
        super(GASConv, self).__init__()

        self.activation = activation
        self.dropout = nn.Dropout(dropout)

        self.e_linear = nn.Linear(e_in_dim, e_out_dim)
        self.u_linear = nn.Linear(u_in_dim, e_out_dim)
        self.v_linear = nn.Linear(v_in_dim, e_out_dim)

        self.W_ATTN_u = nn.Linear(u_in_dim, v_in_dim + e_in_dim)
        self.W_ATTN_v = nn.Linear(v_in_dim, u_in_dim + e_in_dim)

        # The proportions of h_u and h_Nu in formula 8 are both fixed to 1/2:
        # half of each output dimension comes from the node's own projection
        # (Vu / Vv) and half from its aggregated neighborhood (W_u / W_v).
        nu_dim = int(u_out_dim / 2)
        nv_dim = int(v_out_dim / 2)

        self.W_u = nn.Linear(v_in_dim + e_in_dim, nu_dim)
        self.W_v = nn.Linear(u_in_dim + e_in_dim, nv_dim)

        self.Vu = nn.Linear(u_in_dim, u_out_dim - nu_dim)
        self.Vv = nn.Linear(v_in_dim, v_out_dim - nv_dim)

    def forward(self, g, e_feat, u_feat, v_feat):
        with g.local_scope():
            g.nodes["u"].data["h"] = u_feat
            g.nodes["v"].data["h"] = v_feat
            g.edges["forward"].data["h"] = e_feat
            g.edges["backward"].data["h"] = e_feat

            # formulas 3 and 4 (optimized implementation to save memory)
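            # Note: if formulas 3-4 apply a single Linear to [h_e || h_u || h_v],
            # splitting its weight matrix column-wise into the three smaller
            # Linears below and summing per edge computes the same affine map
            # (up to how the bias is parameterized) without ever materializing
            # the concatenated per-edge tensor.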
            g.nodes["u"].data.update({"he_u": self.u_linear(u_feat)})
            g.nodes["v"].data.update({"he_v": self.v_linear(v_feat)})
            g.edges["forward"].data.update({"he_e": self.e_linear(e_feat)})
            g.apply_edges(
                lambda edges: {
                    "he": edges.data["he_e"]
                    + edges.src["he_u"]
                    + edges.dst["he_v"]
                },
                etype="forward",
            )
            he = g.edges["forward"].data["he"]
            if self.activation is not None:
                he = self.activation(he)

            # formula 6
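            # h_ve / h_ue are the per-edge messages [h_src || h_e] that the
            # attention below weights and aggregates onto the opposite side.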
            g.apply_edges(
                lambda edges: {
                    "h_ve": th.cat([edges.src["h"], edges.data["h"]], -1)
                },
                etype="backward",
            )
            g.apply_edges(
                lambda edges: {
                    "h_ue": th.cat([edges.src["h"], edges.data["h"]], -1)
                },
                etype="forward",
            )

            # formula 7, self-attention
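            # W_ATTN_u / W_ATTN_v project each node to the same width as the
            # incoming edge messages so the dot product in Step 1 is defined.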
            g.nodes["u"].data["h_att_u"] = self.W_ATTN_u(u_feat)
            g.nodes["v"].data["h_att_v"] = self.W_ATTN_v(v_feat)

            # Step 1. Dot product
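            # fn.e_dot_v takes, per edge, the dot product of the edge message
            # with the destination node's attention vector, producing the
            # unnormalized attention logit "edotv".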
            g.apply_edges(
                fn.e_dot_v("h_ve", "h_att_u", "edotv"), etype="backward"
            )
            g.apply_edges(
                fn.e_dot_v("h_ue", "h_att_v", "edotv"), etype="forward"
            )

            # Step 2. softmax
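            # edge_softmax normalizes "edotv" over the incoming edges of each
            # destination node (per edge type), so the attention weights of
            # every destination sum to 1.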
            g.edges["backward"].data["sfm"] = edge_softmax(
                g["backward"], g.edges["backward"].data["edotv"]
            )
            g.edges["forward"].data["sfm"] = edge_softmax(
                g["forward"], g.edges["forward"].data["edotv"]
            )

            # Step 3. Weight each edge message by its softmax score; this
            # completes the attention
            g.apply_edges(
                lambda edges: {"attn": edges.data["h_ve"] * edges.data["sfm"]},
                etype="backward",
            )
            g.apply_edges(
                lambda edges: {"attn": edges.data["h_ue"] * edges.data["sfm"]},
                etype="forward",
            )

            # Step 4. Aggregate the weighted messages onto the destination
            # nodes, completing formula 7
            g.update_all(
                fn.copy_e("attn", "m"), fn.sum("m", "agg_u"), etype="backward"
            )
            g.update_all(
                fn.copy_e("attn", "m"), fn.sum("m", "agg_v"), etype="forward"
            )

            # formula 5
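            # agg_u / agg_v are the attention-weighted neighborhood summaries
            # from Step 4; W_u / W_v project them to half of the output width.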
            h_nu = self.W_u(g.nodes["u"].data["agg_u"])
            h_nv = self.W_v(g.nodes["v"].data["agg_v"])
            if self.activation is not None:
                h_nu = self.activation(h_nu)
                h_nv = self.activation(h_nv)

            # Dropout
            he = self.dropout(he)
            h_nu = self.dropout(h_nu)
            h_nv = self.dropout(h_nv)

            # formula 8
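            # Each output concatenates a projection of the node's own features
            # with its neighborhood summary, matching the 1/2 split above.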
            hu = th.cat([self.Vu(u_feat), h_nu], -1)
            hv = th.cat([self.Vv(v_feat), h_nv], -1)

            return he, hu, hv


class GAS(nn.Module):
    """Full GAS model: a stack of GASConv layers followed by an MLP that
    scores every "forward" edge."""

    def __init__(
        self,
        e_in_dim,
        u_in_dim,
        v_in_dim,
        e_hid_dim,
        u_hid_dim,
        v_hid_dim,
        out_dim,
        num_layers=2,
        dropout=0.0,
        activation=None,
    ):
        super(GAS, self).__init__()
        self.e_in_dim = e_in_dim
        self.u_in_dim = u_in_dim
        self.v_in_dim = v_in_dim
        self.e_hid_dim = e_hid_dim
        self.u_hid_dim = u_hid_dim
        self.v_hid_dim = v_hid_dim
        self.out_dim = out_dim
        self.num_layer = num_layers
        self.dropout = dropout
        self.activation = activation
        self.predictor = MLP(e_hid_dim + u_hid_dim + v_hid_dim, out_dim)
        self.layers = nn.ModuleList()

        # Input layer
        self.layers.append(
            GASConv(
                self.e_in_dim,
                self.u_in_dim,
                self.v_in_dim,
                self.e_hid_dim,
                self.u_hid_dim,
                self.v_hid_dim,
                activation=self.activation,
                dropout=self.dropout,
            )
        )

        # Hidden layers: num_layers - 1 additional GASConv layers
        for i in range(self.num_layer - 1):
            self.layers.append(
                GASConv(
                    self.e_hid_dim,
                    self.u_hid_dim,
                    self.v_hid_dim,
                    self.e_hid_dim,
                    self.u_hid_dim,
                    self.v_hid_dim,
                    activation=self.activation,
                    dropout=self.dropout,
                )
            )

    def forward(self, graph, e_feat, u_feat, v_feat):
        # For full graph training, directly use the graph
        # Run the features through all GASConv layers
        for layer in self.layers:
            e_feat, u_feat, v_feat = layer(graph, e_feat, u_feat, v_feat)

        # Return the per-edge scores from the final prediction layer
        return self.predictor(graph, e_feat, u_feat, v_feat)
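

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original example): it
    # builds a tiny bipartite heterograph with the node types ("u", "v") and
    # edge types ("forward": u -> v, "backward": v -> u) the model expects,
    # then runs one forward pass. Every size and hyperparameter below is an
    # arbitrary assumption chosen only to exercise the code.
    import dgl

    u_src = th.tensor([0, 0, 1, 2])
    v_dst = th.tensor([0, 1, 1, 2])
    graph = dgl.heterograph(
        {
            ("u", "forward", "v"): (u_src, v_dst),
            ("v", "backward", "u"): (v_dst, u_src),
        }
    )

    # GASConv.forward writes e_feat to both edge types, so the "forward" and
    # "backward" edges share one feature matrix.
    e_feat = th.randn(graph.num_edges("forward"), 16)
    u_feat = th.randn(graph.num_nodes("u"), 8)
    v_feat = th.randn(graph.num_nodes("v"), 12)

    model = GAS(
        e_in_dim=16,
        u_in_dim=8,
        v_in_dim=12,
        e_hid_dim=32,
        u_hid_dim=32,
        v_hid_dim=32,
        out_dim=2,
        num_layers=2,
        dropout=0.1,
        activation=th.relu,
    )
    logits = model(graph, e_feat, u_feat, v_feat)
    print(logits.shape)  # expected: (number of "forward" edges, out_dim)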