Unverified Commit 6c0cc1fb authored by Maybewuss, committed by GitHub

modify code with builtin functions (#2394)



* modify code with builtin functions

* use einsum
Co-authored-by: Zihao Ye <expye@outlook.com>
parent 013d1456
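
For context, this change replaces the HGT example's Python message/reduce UDFs with DGL built-in functions and edge_softmax. Below is a minimal sketch of that pattern on a toy homogeneous graph; the graph and tensor names are illustrative, not taken from the commit:

import torch
import dgl
import dgl.function as fn
from dgl.ops import edge_softmax

# Toy graph with 3 nodes and 3 edges (hypothetical, for illustration only).
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata['v'] = torch.randn(3, 4)            # per-node message features
logits = torch.randn(g.num_edges(), 1)      # unnormalized attention scores
g.edata['t'] = edge_softmax(g, logits)      # softmax over each dst's in-edges
g.update_all(fn.u_mul_e('v', 't', 'm'),     # message: src feature * edge weight
             fn.sum('m', 'h'))              # reduce: sum messages at each dst
h = g.ndata['h']                            # aggregated output, shape (3, 4)
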
@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.ops import edge_softmax
class HGTLayer(nn.Module):
def __init__(self,
@@ -27,14 +28,14 @@ class HGTLayer(nn.Module):
self.d_k = out_dim // n_heads
self.sqrt_dk = math.sqrt(self.d_k)
self.att = None
self.k_linears = nn.ModuleList()
self.q_linears = nn.ModuleList()
self.v_linears = nn.ModuleList()
self.a_linears = nn.ModuleList()
self.norms = nn.ModuleList()
self.use_norm = use_norm
for t in range(self.num_types):
self.k_linears.append(nn.Linear(in_dim, out_dim))
self.q_linears.append(nn.Linear(in_dim, out_dim))
@@ -42,62 +43,52 @@ class HGTLayer(nn.Module):
self.a_linears.append(nn.Linear(out_dim, out_dim))
if use_norm:
self.norms.append(nn.LayerNorm(out_dim))
self.relation_pri = nn.Parameter(torch.ones(self.num_relations, self.n_heads))
self.relation_att = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
self.relation_msg = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
self.skip = nn.Parameter(torch.ones(self.num_types))
self.drop = nn.Dropout(dropout)
nn.init.xavier_uniform_(self.relation_att)
nn.init.xavier_uniform_(self.relation_msg)
def edge_attention(self, edges):
etype = edges.data['id'][0]
'''
Step 1: Heterogeneous Mutual Attention
'''
relation_att = self.relation_att[etype]
relation_pri = self.relation_pri[etype]
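# Per-head scaled dot-product attention: project the source keys through the
# relation-specific matrix, dot with the destination queries, then scale by
# the learned per-relation prior and 1/sqrt(d_k).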
key = torch.bmm(edges.src['k'].transpose(1,0), relation_att).transpose(1,0)
att = (edges.dst['q'] * key).sum(dim=-1) * relation_pri / self.sqrt_dk
'''
Step 2: Heterogeneous Message Passing
'''
relation_msg = self.relation_msg[etype]
val = torch.bmm(edges.src['v'].transpose(1,0), relation_msg).transpose(1,0)
return {'a': att, 'v': val}
def message_func(self, edges):
return {'v': edges.data['v'], 'a': edges.data['a']}
def reduce_func(self, nodes):
'''
Softmax based on the target node's id (edge_index_i).
NOTE: with DGL's API this softmax differs slightly from the original one:
it is applied only over edges belonging to the same relation type, rather
than over all incoming edges.
'''
att = F.softmax(nodes.mailbox['a'], dim=1)
h = torch.sum(att.unsqueeze(dim = -1) * nodes.mailbox['v'], dim=1)
return {'t': h.view(-1, self.out_dim)}
def forward(self, G, h):
with G.local_scope():
node_dict, edge_dict = self.node_dict, self.edge_dict
for srctype, etype, dsttype in G.canonical_etypes:
sub_graph = G[srctype, etype, dsttype]
k_linear = self.k_linears[node_dict[srctype]]
v_linear = self.v_linears[node_dict[srctype]]
q_linear = self.q_linears[node_dict[dsttype]]
G.nodes[srctype].data['k'] = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
G.nodes[srctype].data['v'] = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
G.nodes[dsttype].data['q'] = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)
G.apply_edges(func=self.edge_attention, etype=etype)
G.multi_update_all({etype : (self.message_func, self.reduce_func) \
k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)
e_id = self.edge_dict[etype]
relation_att = self.relation_att[e_id]
relation_pri = self.relation_pri[e_id]
relation_msg = self.relation_msg[e_id]
k = torch.einsum("bij,ijk->bik", k, relation_att)
v = torch.einsum("bij,ijk->bik", v, relation_msg)
sub_graph.srcdata['k'] = k
sub_graph.dstdata['q'] = q
sub_graph.srcdata['v'] = v
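# fn.v_dot_u dots dst 'q' with src 'k' over the last dim, one score per edge
# and head (with a trailing dim of 1, hence the .sum(-1) below); edge_softmax
# then normalizes scores over each destination's incoming edges, i.e. per
# relation subgraph, matching the note in the removed reduce_func.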
sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk
attn_score = edge_softmax(sub_graph, attn_score, norm_by='dst')
sub_graph.edata['t'] = attn_score.unsqueeze(-1)
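# The built-in u_mul_e/sum pair replaces the old message_func/reduce_func:
# weight each source 'v' by its edge attention 't' and sum the messages at
# the destination; cross_reducer='mean' averages results across relations
# that share a destination node type.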
G.multi_update_all({etype : (fn.u_mul_e('v', 't', 'm'), fn.sum('m', 't')) \
for etype in edge_dict}, cross_reducer = 'mean')
new_h = {}
for ntype in G.ntypes:
'''
@@ -106,14 +97,15 @@ class HGTLayer(nn.Module):
'''
n_id = node_dict[ntype]
alpha = torch.sigmoid(self.skip[n_id])
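# self.skip[n_id] is a learnable per-node-type scalar; alpha gates how much
# of the transformed output vs. the layer's input feature is kept below.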
trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
t = G.nodes[ntype].data['t'].view(-1, self.out_dim)
trans_out = self.drop(self.a_linears[n_id](t))
trans_out = trans_out * alpha + h[ntype] * (1-alpha)
if self.use_norm:
new_h[ntype] = self.norms[n_id](trans_out)
else:
new_h[ntype] = trans_out
return new_h
class HGT(nn.Module):
def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
super(HGT, self).__init__()
@@ -167,8 +159,8 @@ class HeteroRGCNLayer(nn.Module):
G.multi_update_all(funcs, 'sum')
# return the updated node feature dictionary
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
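
The loop that builds funcs is elided above. In the standard DGL HeteroRGCN example it looks roughly like the sketch below; this is an assumption based on the usual tutorial code, not the exact contents of this file:

funcs = {}
for srctype, etype, dsttype in G.canonical_etypes:
    # project source features with the relation-specific weight, then send
    # them as messages and average them at the destination nodes
    Wh = self.weight[etype](feat_dict[srctype])
    G.nodes[srctype].data['Wh_%s' % etype] = Wh
    funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
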
class HeteroRGCN(nn.Module):
def __init__(self, G, in_size, hidden_size, out_size):
super(HeteroRGCN, self).__init__()
@@ -13,7 +13,7 @@ from model import *
import argparse
torch.manual_seed(0)
data_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/ACM.mat'
data_url = 'https://data.dgl.ai/dataset/ACM.mat'
data_file_path = '/tmp/ACM.mat'
urllib.request.urlretrieve(data_url, data_file_path)
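
The script then parses the downloaded .mat file; a minimal sketch of the usual next step, assuming the standard scipy loader used by the ACM examples:

import scipy.io
data = scipy.io.loadmat(data_file_path)  # dict of matrices keyed by name
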