Unverified commit 6c0cc1fb, authored by Maybewuss, committed by GitHub

modify code with built-in functions (#2394)



* modify code with built-in functions

* use einsum

Co-authored-by: Zihao Ye <expye@outlook.com>
parent 013d1456
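For context on the "use einsum" change below: the old edge_attention applied the per-relation weight matrices with torch.bmm plus two transposes, while torch.einsum("bij,ijk->bik", x, W) expresses the same per-head batched matrix multiply in one call. A minimal sketch of the equivalence, with made-up sizes (n nodes, h heads, head dimension d):

    import torch

    n, h, d = 5, 4, 8                  # hypothetical sizes: 5 nodes, 4 heads, head dim 8
    x = torch.randn(n, h, d)           # per-node, per-head features ('k' or 'v' in the diff)
    W = torch.randn(h, d, d)           # one d-by-d matrix per head (like relation_att[e_id])

    # Old pattern: put the head axis first so bmm sees (h, n, d) @ (h, d, d), then undo it.
    old = torch.bmm(x.transpose(1, 0), W).transpose(1, 0)

    # New pattern: for each batch b and head i, contract over j; no transposes needed.
    new = torch.einsum("bij,ijk->bik", x, W)

    assert torch.allclose(old, new, atol=1e-5)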
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import dgl.function as fn
+from dgl.ops import edge_softmax

 class HGTLayer(nn.Module):
     def __init__(self,
@@ -27,14 +28,14 @@ class HGTLayer(nn.Module):
         self.d_k = out_dim // n_heads
         self.sqrt_dk = math.sqrt(self.d_k)
         self.att = None

         self.k_linears = nn.ModuleList()
         self.q_linears = nn.ModuleList()
         self.v_linears = nn.ModuleList()
         self.a_linears = nn.ModuleList()
         self.norms = nn.ModuleList()
         self.use_norm = use_norm

         for t in range(self.num_types):
             self.k_linears.append(nn.Linear(in_dim, out_dim))
             self.q_linears.append(nn.Linear(in_dim, out_dim))
@@ -42,62 +43,52 @@ class HGTLayer(nn.Module):
             self.a_linears.append(nn.Linear(out_dim, out_dim))
             if use_norm:
                 self.norms.append(nn.LayerNorm(out_dim))

         self.relation_pri = nn.Parameter(torch.ones(self.num_relations, self.n_heads))
         self.relation_att = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
         self.relation_msg = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
         self.skip = nn.Parameter(torch.ones(self.num_types))
         self.drop = nn.Dropout(dropout)

         nn.init.xavier_uniform_(self.relation_att)
         nn.init.xavier_uniform_(self.relation_msg)

-    def edge_attention(self, edges):
-        etype = edges.data['id'][0]
-        '''
-            Step 1: Heterogeneous Mutual Attention
-        '''
-        relation_att = self.relation_att[etype]
-        relation_pri = self.relation_pri[etype]
-        key = torch.bmm(edges.src['k'].transpose(1, 0), relation_att).transpose(1, 0)
-        att = (edges.dst['q'] * key).sum(dim=-1) * relation_pri / self.sqrt_dk
-        '''
-            Step 2: Heterogeneous Message Passing
-        '''
-        relation_msg = self.relation_msg[etype]
-        val = torch.bmm(edges.src['v'].transpose(1, 0), relation_msg).transpose(1, 0)
-        return {'a': att, 'v': val}
-
-    def message_func(self, edges):
-        return {'v': edges.data['v'], 'a': edges.data['a']}
-
-    def reduce_func(self, nodes):
-        '''
-            Softmax based on the target node's id (edge_index_i).
-            NOTE: with DGL's API there is a minor difference from the original softmax:
-            this implementation applies softmax only over edges of the same relation
-            type, rather than over all incoming edges.
-        '''
-        att = F.softmax(nodes.mailbox['a'], dim=1)
-        h = torch.sum(att.unsqueeze(dim=-1) * nodes.mailbox['v'], dim=1)
-        return {'t': h.view(-1, self.out_dim)}
-
     def forward(self, G, h):
         with G.local_scope():
             node_dict, edge_dict = self.node_dict, self.edge_dict
             for srctype, etype, dsttype in G.canonical_etypes:
+                sub_graph = G[srctype, etype, dsttype]
+
                 k_linear = self.k_linears[node_dict[srctype]]
                 v_linear = self.v_linears[node_dict[srctype]]
                 q_linear = self.q_linears[node_dict[dsttype]]

-                G.nodes[srctype].data['k'] = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
-                G.nodes[srctype].data['v'] = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
-                G.nodes[dsttype].data['q'] = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)
-
-                G.apply_edges(func=self.edge_attention, etype=etype)
-
-            G.multi_update_all({etype : (self.message_func, self.reduce_func) \
+                k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
+                v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
+                q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)
+
+                e_id = self.edge_dict[etype]
+
+                relation_att = self.relation_att[e_id]
+                relation_pri = self.relation_pri[e_id]
+                relation_msg = self.relation_msg[e_id]
+
+                k = torch.einsum("bij,ijk->bik", k, relation_att)
+                v = torch.einsum("bij,ijk->bik", v, relation_msg)
+
+                sub_graph.srcdata['k'] = k
+                sub_graph.dstdata['q'] = q
+                sub_graph.srcdata['v'] = v
+
+                sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
+                attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk
+                attn_score = edge_softmax(sub_graph, attn_score, norm_by='dst')
+
+                sub_graph.edata['t'] = attn_score.unsqueeze(-1)
+
+            G.multi_update_all({etype : (fn.u_mul_e('v', 't', 'm'), fn.sum('m', 't')) \
                     for etype in edge_dict}, cross_reducer = 'mean')

             new_h = {}
             for ntype in G.ntypes:
                 '''
@@ -106,14 +97,15 @@ class HGTLayer(nn.Module):
                 '''
                 n_id = node_dict[ntype]
                 alpha = torch.sigmoid(self.skip[n_id])
-                trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
+                t = G.nodes[ntype].data['t'].view(-1, self.out_dim)
+                trans_out = self.drop(self.a_linears[n_id](t))
                 trans_out = trans_out * alpha + h[ntype] * (1-alpha)
                 if self.use_norm:
                     new_h[ntype] = self.norms[n_id](trans_out)
                 else:
                     new_h[ntype] = trans_out
             return new_h

 class HGT(nn.Module):
     def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
         super(HGT, self).__init__()
@@ -167,8 +159,8 @@ class HeteroRGCNLayer(nn.Module):
         G.multi_update_all(funcs, 'sum')
         # return the updated node feature dictionary
         return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}

 class HeteroRGCN(nn.Module):
     def __init__(self, G, in_size, hidden_size, out_size):
         super(HeteroRGCN, self).__init__()
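The rewritten forward pass above leans on two built-in DGL pieces: the fn.v_dot_u / fn.u_mul_e / fn.sum message primitives and dgl.ops.edge_softmax, which replaces the hand-written softmax in the removed reduce_func. A minimal, self-contained sketch of that pipeline on a homogeneous toy graph (the graph, sizes, and feature names here are invented for illustration):

    import torch
    import dgl
    import dgl.function as fn
    from dgl.ops import edge_softmax

    g = dgl.graph(([0, 1, 2, 0], [1, 2, 0, 2]))      # toy graph: 3 nodes, 4 edges
    n_heads, d_k = 2, 4
    g.srcdata['k'] = torch.randn(3, n_heads, d_k)
    g.srcdata['v'] = torch.randn(3, n_heads, d_k)
    g.dstdata['q'] = torch.randn(3, n_heads, d_k)

    # Unnormalized score per edge and head: dot(q of dst, k of src).
    g.apply_edges(fn.v_dot_u('q', 'k', 't'))          # edata['t']: (E, heads, 1)
    score = g.edata.pop('t').sum(-1) / d_k ** 0.5     # (E, heads)

    # Softmax over each destination node's incoming edges.
    att = edge_softmax(g, score, norm_by='dst')       # (E, heads)

    # Message: source 'v' scaled by the edge weight; reduce: sum over in-edges.
    g.edata['t'] = att.unsqueeze(-1)
    g.update_all(fn.u_mul_e('v', 't', 'm'), fn.sum('m', 'h'))
    out = g.dstdata['h']                              # (3, heads, d_k)

Because edge_softmax normalizes over the in-edges of each destination node within the graph it is given, calling it on each per-relation sub_graph (as the new code does) preserves the per-relation softmax behavior described in the removed reduce_func docstring.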
@@ -13,7 +13,7 @@ from model import *
 import argparse

 torch.manual_seed(0)
-data_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/ACM.mat'
+data_url = 'https://data.dgl.ai/dataset/ACM.mat'
 data_file_path = '/tmp/ACM.mat'

 urllib.request.urlretrieve(data_url, data_file_path)