import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn


class HGTLayer(nn.Module):
    def __init__(self, in_dim, out_dim, num_types, num_relations,
                 n_heads, dropout=0.2, use_norm=False):
        super(HGTLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_types = num_types
        self.num_relations = num_relations
        self.total_rel = num_types * num_relations * num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.att = None

        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.use_norm = use_norm

        # Per-node-type projections for keys, queries, values and the output transform.
        for t in range(num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        # Per-relation parameters: attention prior plus relation-specific
        # attention and message transformation matrices.
        self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads))
        self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.skip = nn.Parameter(torch.ones(num_types))
        self.drop = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)

    def edge_attention(self, edges):
        etype = edges.data['id'][0]
        '''
            Step 1: Heterogeneous Mutual Attention
        '''
        relation_att = self.relation_att[etype]
        relation_pri = self.relation_pri[etype]
        # edges.src['k']: (E, n_heads, d_k) -> bmm with (n_heads, d_k, d_k) -> (E, n_heads, d_k)
        key = torch.bmm(edges.src['k'].transpose(1, 0), relation_att).transpose(1, 0)
        att = (edges.dst['q'] * key).sum(dim=-1) * relation_pri / self.sqrt_dk
        '''
            Step 2: Heterogeneous Message Passing
        '''
        relation_msg = self.relation_msg[etype]
        val = torch.bmm(edges.src['v'].transpose(1, 0), relation_msg).transpose(1, 0)
        return {'a': att, 'v': val}

    def message_func(self, edges):
        return {'v': edges.data['v'], 'a': edges.data['a']}

    def reduce_func(self, nodes):
        '''
            Softmax based on target node's id (edge_index_i).
            NOTE: With DGL's API there is a minor difference between this softmax and
                  the original one: it normalizes only over edges belonging to the same
                  relation type, instead of over all incoming edges.
        '''
        att = F.softmax(nodes.mailbox['a'], dim=1)
        h = torch.sum(att.unsqueeze(dim=-1) * nodes.mailbox['v'], dim=1)
        return {'t': h.view(-1, self.out_dim)}

    def forward(self, G, inp_key, out_key):
        node_dict, edge_dict = G.node_dict, G.edge_dict
        for srctype, etype, dsttype in G.canonical_etypes:
            k_linear = self.k_linears[node_dict[srctype]]
            v_linear = self.v_linears[node_dict[srctype]]
            q_linear = self.q_linears[node_dict[dsttype]]

            G.nodes[srctype].data['k'] = k_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.nodes[srctype].data['v'] = v_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.nodes[dsttype].data['q'] = q_linear(G.nodes[dsttype].data[inp_key]).view(-1, self.n_heads, self.d_k)

            G.apply_edges(func=self.edge_attention, etype=etype)

        G.multi_update_all({etype: (self.message_func, self.reduce_func)
                            for etype in edge_dict},
                           cross_reducer='mean')

        for ntype in G.ntypes:
            '''
                Step 3: Target-specific Aggregation
                x = norm( W[node_type] * gelu( Agg(x) ) + x )
            '''
            n_id = node_dict[ntype]
            alpha = torch.sigmoid(self.skip[n_id])
            trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
            trans_out = trans_out * alpha + G.nodes[ntype].data[inp_key] * (1 - alpha)
            if self.use_norm:
                G.nodes[ntype].data[out_key] = self.norms[n_id](trans_out)
            else:
                # Still write the layer output when LayerNorm is disabled.
                G.nodes[ntype].data[out_key] = trans_out


class HGT(nn.Module):
    def __init__(self, G, n_inp, n_hid, n_out, n_layers, n_heads, use_norm=True):
        super(HGT, self).__init__()
        self.gcs = nn.ModuleList()
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        self.adapt_ws = nn.ModuleList()
        # One adaptation layer per node type to project inputs into the hidden space.
        for t in range(len(G.node_dict)):
            self.adapt_ws.append(nn.Linear(n_inp, n_hid))
        for _ in range(n_layers):
            self.gcs.append(HGTLayer(n_hid, n_hid, len(G.node_dict), len(G.edge_dict),
                                     n_heads, use_norm=use_norm))
        self.out = nn.Linear(n_hid, n_out)

    def forward(self, G, out_key):
        for ntype in G.ntypes:
            n_id = G.node_dict[ntype]
            G.nodes[ntype].data['h'] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data['inp']))
        for i in range(self.n_layers):
            self.gcs[i](G, 'h', 'h')
        return self.out(G.nodes[out_key].data['h'])
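# A minimal usage sketch for HGT (illustrative, not part of the original example).
# It assumes the accompanying training script has prepared `G` with:
#   * a node feature 'inp' of size n_inp for every node type,
#   * `G.node_dict` / `G.edge_dict` attributes mapping type names to integer ids,
#   * an integer edge feature 'id' holding each edge's relation id (read by edge_attention).
# The node type 'paper' and the sizes below are assumptions:
#
#   model = HGT(G, n_inp=256, n_hid=256, n_out=4, n_layers=2, n_heads=4)
#   logits = model(G, out_key='paper')  # logits for every node of the assumed type 'paper'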
class HeteroRGCNLayer(nn.Module):
    def __init__(self, in_size, out_size, etypes):
        super(HeteroRGCNLayer, self).__init__()
        # W_r for each relation.
        self.weight = nn.ModuleDict({
            name: nn.Linear(in_size, out_size) for name in etypes
        })

    def forward(self, G, feat_dict):
        # The input is a dictionary of node features, one entry per node type.
        funcs = {}
        for srctype, etype, dsttype in G.canonical_etypes:
            # Compute W_r * h.
            Wh = self.weight[etype](feat_dict[srctype])
            # Save it in the graph for message passing.
            G.nodes[srctype].data['Wh_%s' % etype] = Wh
            # Specify per-relation message passing functions: (message_func, reduce_func).
            # Note that the results are saved to the same destination feature 'h',
            # which hints the type-wise reducer for aggregation.
            funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
        # Trigger message passing over multiple edge types.
        # The first argument is the per-relation message passing functions.
        # The second one is the type-wise reducer, which can be "sum", "max",
        # "min", "mean", or "stack".
        G.multi_update_all(funcs, 'sum')
        # Return the updated node feature dictionary.
        return {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes}


class HeteroRGCN(nn.Module):
    def __init__(self, G, in_size, hidden_size, out_size):
        super(HeteroRGCN, self).__init__()
        # Create layers.
        self.layer1 = HeteroRGCNLayer(in_size, hidden_size, G.etypes)
        self.layer2 = HeteroRGCNLayer(hidden_size, out_size, G.etypes)

    def forward(self, G, out_key):
        input_dict = {ntype: G.nodes[ntype].data['inp'] for ntype in G.ntypes}
        h_dict = self.layer1(G, input_dict)
        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}
        h_dict = self.layer2(G, h_dict)
        # Return logits for the requested output node type (e.g. paper).
        return h_dict[out_key]
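# Minimal smoke test, not part of the original example: the toy graph, the type names
# ('author', 'paper', 'writes', 'written-by'), the feature sizes, and the random 'inp'
# features below are illustrative assumptions, and a recent DGL that accepts
# (src, dst) tensor pairs in dgl.heterograph is assumed. It only exercises HeteroRGCN;
# HGT additionally needs `G.node_dict`, `G.edge_dict` and per-edge 'id' features
# (see the sketch above).
if __name__ == '__main__':
    G = dgl.heterograph({
        ('author', 'writes', 'paper'): (torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1])),
        ('paper', 'written-by', 'author'): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 1])),
    })
    # Attach random input features of size 16 to every node type.
    for ntype in G.ntypes:
        G.nodes[ntype].data['inp'] = torch.randn(G.number_of_nodes(ntype), 16)
    model = HeteroRGCN(G, in_size=16, hidden_size=32, out_size=4)
    logits = model(G, out_key='paper')
    print(logits.shape)  # expected: torch.Size([2, 4])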