import torch as th
import torch.nn as nn

import dgl.function as fn
from dgl.nn.functional import edge_softmax


class MLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim)

    def apply_edges(self, edges):
        h_e = edges.data["h"]
        h_u = edges.src["h"]
        h_v = edges.dst["h"]
        score = self.W(th.cat([h_e, h_u, h_v], -1))
        return {"score": score}

    def forward(self, g, e_feat, u_feat, v_feat):
        with g.local_scope():
            g.edges["forward"].data["h"] = e_feat
            g.nodes["u"].data["h"] = u_feat
            g.nodes["v"].data["h"] = v_feat
            g.apply_edges(self.apply_edges, etype="forward")
            return g.edges["forward"].data["score"]


class GASConv(nn.Module):
    """One layer of GAS."""

    def __init__(
        self,
        e_in_dim,
        u_in_dim,
        v_in_dim,
        e_out_dim,
        u_out_dim,
        v_out_dim,
        activation=None,
        dropout=0,
    ):
        super(GASConv, self).__init__()
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.e_linear = nn.Linear(e_in_dim, e_out_dim)
        self.u_linear = nn.Linear(u_in_dim, e_out_dim)
        self.v_linear = nn.Linear(v_in_dim, e_out_dim)
        self.W_ATTN_u = nn.Linear(u_in_dim, v_in_dim + e_in_dim)
        self.W_ATTN_v = nn.Linear(v_in_dim, u_in_dim + e_in_dim)

        # the proportions of h_u and h_Nu are specified as 1/2 in formula 8
        nu_dim = int(u_out_dim / 2)
        nv_dim = int(v_out_dim / 2)
        self.W_u = nn.Linear(v_in_dim + e_in_dim, nu_dim)
        self.W_v = nn.Linear(u_in_dim + e_in_dim, nv_dim)
        self.Vu = nn.Linear(u_in_dim, u_out_dim - nu_dim)
        self.Vv = nn.Linear(v_in_dim, v_out_dim - nv_dim)

    def forward(self, g, f_feat, b_feat, u_feat, v_feat):
        g.srcnodes["u"].data["h"] = u_feat
        g.srcnodes["v"].data["h"] = v_feat
        g.dstnodes["u"].data["h"] = u_feat[: g.number_of_dst_nodes(ntype="u")]
        g.dstnodes["v"].data["h"] = v_feat[: g.number_of_dst_nodes(ntype="v")]
        g.edges["forward"].data["h"] = f_feat
        g.edges["backward"].data["h"] = b_feat

        # formula 3 and 4 (optimized implementation to save memory)
        g.srcnodes["u"].data.update(
            {"he_u": self.u_linear(g.srcnodes["u"].data["h"])}
        )
        g.srcnodes["v"].data.update(
            {"he_v": self.v_linear(g.srcnodes["v"].data["h"])}
        )
        g.dstnodes["u"].data.update(
            {"he_u": self.u_linear(g.dstnodes["u"].data["h"])}
        )
        g.dstnodes["v"].data.update(
            {"he_v": self.v_linear(g.dstnodes["v"].data["h"])}
        )
        g.edges["forward"].data.update({"he_e": self.e_linear(f_feat)})
        g.edges["backward"].data.update({"he_e": self.e_linear(b_feat)})
        g.apply_edges(
            lambda edges: {
                "he": edges.data["he_e"] + edges.dst["he_u"] + edges.src["he_v"]
            },
            etype="backward",
        )
        g.apply_edges(
            lambda edges: {
                "he": edges.data["he_e"] + edges.src["he_u"] + edges.dst["he_v"]
            },
            etype="forward",
        )
        hf = g.edges["forward"].data["he"]
        hb = g.edges["backward"].data["he"]
        if self.activation is not None:
            hf = self.activation(hf)
            hb = self.activation(hb)

        # formula 6
        g.apply_edges(
            lambda edges: {
                "h_ve": th.cat([edges.src["h"], edges.data["h"]], -1)
            },
            etype="backward",
        )
        g.apply_edges(
            lambda edges: {
                "h_ue": th.cat([edges.src["h"], edges.data["h"]], -1)
            },
            etype="forward",
        )

        # formula 7, self-attention
        g.srcnodes["u"].data["h_att_u"] = self.W_ATTN_u(
            g.srcnodes["u"].data["h"]
        )
        g.srcnodes["v"].data["h_att_v"] = self.W_ATTN_v(
            g.srcnodes["v"].data["h"]
        )
        g.dstnodes["u"].data["h_att_u"] = self.W_ATTN_u(
            g.dstnodes["u"].data["h"]
        )
        g.dstnodes["v"].data["h_att_v"] = self.W_ATTN_v(
            g.dstnodes["v"].data["h"]
        )

        # Step 1. Dot product
        g.apply_edges(fn.e_dot_v("h_ve", "h_att_u", "edotv"), etype="backward")
        g.apply_edges(fn.e_dot_v("h_ue", "h_att_v", "edotv"), etype="forward")

        # Step 2. Softmax
        g.edges["backward"].data["sfm"] = edge_softmax(
            g["backward"], g.edges["backward"].data["edotv"]
        )
        g.edges["forward"].data["sfm"] = edge_softmax(
            g["forward"], g.edges["forward"].data["edotv"]
        )

        # Step 3. Broadcast the softmax value to each edge, completing the attention weighting
        g.apply_edges(
            lambda edges: {"attn": edges.data["h_ve"] * edges.data["sfm"]},
            etype="backward",
        )
        g.apply_edges(
            lambda edges: {"attn": edges.data["h_ue"] * edges.data["sfm"]},
            etype="forward",
        )

        # Step 4. Aggregate attention to dst (user/item) nodes, which completes formula 7
        g.update_all(
            fn.copy_e("attn", "m"), fn.sum("m", "agg_u"), etype="backward"
        )
        g.update_all(
            fn.copy_e("attn", "m"), fn.sum("m", "agg_v"), etype="forward"
        )

        # formula 5
        h_nu = self.W_u(g.dstnodes["u"].data["agg_u"])
        h_nv = self.W_v(g.dstnodes["v"].data["agg_v"])
        if self.activation is not None:
            h_nu = self.activation(h_nu)
            h_nv = self.activation(h_nv)

        # Dropout
        hf = self.dropout(hf)
        hb = self.dropout(hb)
        h_nu = self.dropout(h_nu)
        h_nv = self.dropout(h_nv)

        # formula 8
        hu = th.cat([self.Vu(g.dstnodes["u"].data["h"]), h_nu], -1)
        hv = th.cat([self.Vv(g.dstnodes["v"].data["h"]), h_nv], -1)

        return hf, hb, hu, hv


class GAS(nn.Module):
    """GCN-based Anti-Spam (GAS) model."""

    def __init__(
        self,
        e_in_dim,
        u_in_dim,
        v_in_dim,
        e_hid_dim,
        u_hid_dim,
        v_hid_dim,
        out_dim,
        num_layers=2,
        dropout=0.0,
        activation=None,
    ):
        super(GAS, self).__init__()
        self.e_in_dim = e_in_dim
        self.u_in_dim = u_in_dim
        self.v_in_dim = v_in_dim
        self.e_hid_dim = e_hid_dim
        self.u_hid_dim = u_hid_dim
        self.v_hid_dim = v_hid_dim
        self.out_dim = out_dim
        self.num_layer = num_layers
        self.dropout = dropout
        self.activation = activation
        self.predictor = MLP(e_hid_dim + u_hid_dim + v_hid_dim, out_dim)
        self.layers = nn.ModuleList()

        # Input layer
        self.layers.append(
            GASConv(
                self.e_in_dim,
                self.u_in_dim,
                self.v_in_dim,
                self.e_hid_dim,
                self.u_hid_dim,
                self.v_hid_dim,
                activation=self.activation,
                dropout=self.dropout,
            )
        )

        # Hidden layers with n - 1 GASConv layers
        for _ in range(self.num_layer - 1):
            self.layers.append(
                GASConv(
                    self.e_hid_dim,
                    self.u_hid_dim,
                    self.v_hid_dim,
                    self.e_hid_dim,
                    self.u_hid_dim,
                    self.v_hid_dim,
                    activation=self.activation,
                    dropout=self.dropout,
                )
            )

    def forward(self, subgraph, blocks, f_feat, b_feat, u_feat, v_feat):
        # Forward of n layers of GAS
        for layer, block in zip(self.layers, blocks):
            f_feat, b_feat, u_feat, v_feat = layer(
                block,
                f_feat[: block.num_edges(etype="forward")],
                b_feat[: block.num_edges(etype="backward")],
                u_feat,
                v_feat,
            )

        # return the result of the final prediction layer
        return self.predictor(
            subgraph,
            f_feat[: subgraph.num_edges(etype="forward")],
            u_feat,
            v_feat,
        )
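# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original example).
# It assumes the model can be run full-graph by letting the heterograph stand
# in for both the sampled blocks and the prediction subgraph; in the DGL GAS
# example the blocks normally come from a neighbor-sampling edge dataloader.
# The tiny graph, feature sizes, and hidden sizes below are arbitrary
# placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import dgl

    # Tiny review graph: "u" (user) nodes connect to "v" (item) nodes via
    # "forward" edges, with an explicit reverse "backward" relation so that
    # messages can flow in both directions.
    u_ids = th.tensor([0, 0, 1, 2])
    v_ids = th.tensor([0, 1, 1, 2])
    graph = dgl.heterograph(
        {
            ("u", "forward", "v"): (u_ids, v_ids),
            ("v", "backward", "u"): (v_ids, u_ids),
        }
    )

    # Random input features for edges and both node types.
    e_feat = th.randn(graph.num_edges("forward"), 16)
    u_feat = th.randn(graph.num_nodes("u"), 8)
    v_feat = th.randn(graph.num_nodes("v"), 8)

    model = GAS(
        e_in_dim=16,
        u_in_dim=8,
        v_in_dim=8,
        e_hid_dim=32,
        u_hid_dim=32,
        v_hid_dim=32,
        out_dim=2,
        num_layers=2,
        activation=th.relu,
    )

    # Full-graph forward pass: the same graph is reused as every "block",
    # and the forward/backward edges share the same initial features.
    logits = model(graph, [graph] * 2, e_feat, e_feat, u_feat, v_feat)
    print(logits.shape)  # (number of forward edges, out_dim)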