Unverified Commit 6c0cc1fb authored by Maybewuss, committed by GitHub

modify code with built-in function (#2394)



* modify code with built-in function

* use einsum
Co-authored-by: Zihao Ye <expye@outlook.com>
parent 013d1456
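
Note on the change: the commit replaces the hand-written message/reduce UDFs (edge_attention, message_func, reduce_func) with DGL's built-in functions (fn.v_dot_u, fn.u_mul_e, fn.sum) plus dgl.ops.edge_softmax, which run as fused kernels rather than per-bucket Python callbacks. A minimal sketch of the two styles, on a made-up toy graph that is not part of the commit:

import dgl
import dgl.function as fn
import torch

# Toy graph, purely illustrative (not from the commit).
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata['v'] = torch.randn(3, 4)   # node "values"
g.edata['t'] = torch.rand(3, 1)    # per-edge attention weights

# UDF style, as in the old code: Python callbacks over message mailboxes.
def message_func(edges):
    return {'m': edges.src['v'] * edges.data['t']}

def reduce_func(nodes):
    return {'h': nodes.mailbox['m'].sum(dim=1)}

g.update_all(message_func, reduce_func)
h_udf = g.ndata.pop('h')

# Built-in style, as in the new code: fused, mailbox-free kernels.
g.update_all(fn.u_mul_e('v', 't', 'm'), fn.sum('m', 'h'))
h_builtin = g.ndata.pop('h')

assert torch.allclose(h_udf, h_builtin)

Per-relation results are still combined by G.multi_update_all(..., cross_reducer='mean'), unchanged from the old code.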
model.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import dgl.function as fn
+from dgl.ops import edge_softmax
 
 class HGTLayer(nn.Module):
     def __init__(self,
@@ -52,52 +53,42 @@ class HGTLayer(nn.Module):
         nn.init.xavier_uniform_(self.relation_att)
         nn.init.xavier_uniform_(self.relation_msg)
 
-    def edge_attention(self, edges):
-        etype = edges.data['id'][0]
-        '''
-            Step 1: Heterogeneous Mutual Attention
-        '''
-        relation_att = self.relation_att[etype]
-        relation_pri = self.relation_pri[etype]
-        key = torch.bmm(edges.src['k'].transpose(1,0), relation_att).transpose(1,0)
-        att = (edges.dst['q'] * key).sum(dim=-1) * relation_pri / self.sqrt_dk
-        '''
-            Step 2: Heterogeneous Message Passing
-        '''
-        relation_msg = self.relation_msg[etype]
-        val = torch.bmm(edges.src['v'].transpose(1,0), relation_msg).transpose(1,0)
-        return {'a': att, 'v': val}
-
-    def message_func(self, edges):
-        return {'v': edges.data['v'], 'a': edges.data['a']}
-
-    def reduce_func(self, nodes):
-        '''
-            Softmax based on the target node's id (edge_index_i).
-            NOTE: Using DGL's API, there is a minor difference between this softmax and the original one.
-                  This implementation does softmax only over edges belonging to the same relation type, instead of over all edges.
-        '''
-        att = F.softmax(nodes.mailbox['a'], dim=1)
-        h = torch.sum(att.unsqueeze(dim = -1) * nodes.mailbox['v'], dim=1)
-        return {'t': h.view(-1, self.out_dim)}
-
     def forward(self, G, h):
         with G.local_scope():
             node_dict, edge_dict = self.node_dict, self.edge_dict
             for srctype, etype, dsttype in G.canonical_etypes:
+                sub_graph = G[srctype, etype, dsttype]
+
                 k_linear = self.k_linears[node_dict[srctype]]
                 v_linear = self.v_linears[node_dict[srctype]]
                 q_linear = self.q_linears[node_dict[dsttype]]
 
-                G.nodes[srctype].data['k'] = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
-                G.nodes[srctype].data['v'] = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
-                G.nodes[dsttype].data['q'] = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)
-
-                G.apply_edges(func=self.edge_attention, etype=etype)
-
-            G.multi_update_all({etype : (self.message_func, self.reduce_func) \
+                k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
+                v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
+                q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)
+
+                e_id = self.edge_dict[etype]
+
+                relation_att = self.relation_att[e_id]
+                relation_pri = self.relation_pri[e_id]
+                relation_msg = self.relation_msg[e_id]
+
+                k = torch.einsum("bij,ijk->bik", k, relation_att)
+                v = torch.einsum("bij,ijk->bik", v, relation_msg)
+
+                sub_graph.srcdata['k'] = k
+                sub_graph.dstdata['q'] = q
+                sub_graph.srcdata['v'] = v
+
+                sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
+                attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk
+                attn_score = edge_softmax(sub_graph, attn_score, norm_by='dst')
+
+                sub_graph.edata['t'] = attn_score.unsqueeze(-1)
+
+            G.multi_update_all({etype : (fn.u_mul_e('v', 't', 'm'), fn.sum('m', 't')) \
                                 for etype in edge_dict}, cross_reducer = 'mean')
 
             new_h = {}
             for ntype in G.ntypes:
                 '''
@@ -106,7 +97,8 @@ class HGTLayer(nn.Module):
                 '''
                 n_id = node_dict[ntype]
                 alpha = torch.sigmoid(self.skip[n_id])
-                trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
+                t = G.nodes[ntype].data['t'].view(-1, self.out_dim)
+                trans_out = self.drop(self.a_linears[n_id](t))
                 trans_out = trans_out * alpha + h[ntype] * (1-alpha)
                 if self.use_norm:
                     new_h[ntype] = self.norms[n_id](trans_out)
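
The einsum lines above are a drop-in replacement for the bmm-and-transpose pattern in the removed edge_attention: with k of shape (N, n_heads, d_k) and relation_att of shape (n_heads, d_k, d_k), both compute a per-head projection of k. A quick self-contained equivalence check, with arbitrary made-up sizes:

import torch

N, n_heads, d_k = 6, 4, 16   # arbitrary sizes for the check
k = torch.randn(N, n_heads, d_k)
relation_att = torch.randn(n_heads, d_k, d_k)

# Removed code: batch over heads with bmm, transposing in and out.
old = torch.bmm(k.transpose(1, 0), relation_att).transpose(1, 0)

# New code: a single einsum over the head and feature axes.
new = torch.einsum("bij,ijk->bik", k, relation_att)

assert torch.allclose(old, new, atol=1e-5)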
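
Likewise, edge_softmax(sub_graph, attn_score, norm_by='dst') computes what the removed reduce_func computed with F.softmax over the mailbox: attention scores normalized over each destination node's incoming edges, per relation subgraph. A minimal sketch on a made-up graph:

import dgl
import torch
import torch.nn.functional as F
from dgl.ops import edge_softmax

# Made-up graph: edges 0 and 1 point at node 2, edge 2 points at node 1.
g = dgl.graph(([0, 1, 0], [2, 2, 1]))
scores = torch.randn(g.num_edges(), 4)   # e.g. one score per attention head

# Normalize scores over the in-edges of each destination node.
a = edge_softmax(g, scores, norm_by='dst')

# Edges 0 and 1 share destination 2, so they form one softmax group ...
assert torch.allclose(a[[0, 1]], F.softmax(scores[[0, 1]], dim=0))
# ... while a lone in-edge gets weight 1.
assert torch.allclose(a[2], torch.ones_like(a[2]))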
train.py
@@ -13,7 +13,7 @@ from model import *
 import argparse
 
 torch.manual_seed(0)
-data_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/ACM.mat'
+data_url = 'https://data.dgl.ai/dataset/ACM.mat'
 data_file_path = '/tmp/ACM.mat'
 
 urllib.request.urlretrieve(data_url, data_file_path)