Unverified Commit 66f7fe8b authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

Update model.py (#2796)

parent 2952f3c4
...@@ -78,7 +78,7 @@ class HGTLayer(nn.Module): ...@@ -78,7 +78,7 @@ class HGTLayer(nn.Module):
sub_graph.srcdata['k'] = k sub_graph.srcdata['k'] = k
sub_graph.dstdata['q'] = q sub_graph.dstdata['q'] = q
sub_graph.srcdata['v'] = v sub_graph.srcdata['v_%d' % e_id] = v
sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't')) sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk
...@@ -86,8 +86,8 @@ class HGTLayer(nn.Module): ...@@ -86,8 +86,8 @@ class HGTLayer(nn.Module):
sub_graph.edata['t'] = attn_score.unsqueeze(-1) sub_graph.edata['t'] = attn_score.unsqueeze(-1)
G.multi_update_all({etype : (fn.u_mul_e('v', 't', 'm'), fn.sum('m', 't')) \ G.multi_update_all({etype : (fn.u_mul_e('v_%d' % e_id, 't', 'm'), fn.sum('m', 't')) \
for etype in edge_dict}, cross_reducer = 'mean') for etype, e_id in edge_dict.items()}, cross_reducer = 'mean')
new_h = {} new_h = {}
for ntype in G.ntypes: for ntype in G.ntypes:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment