Unverified Commit dba8c825 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files
parent 8f3e423c
...@@ -79,9 +79,8 @@ class Transformer(nn.Module): ...@@ -79,9 +79,8 @@ class Transformer(nn.Module):
g.apply_edges(src_dot_dst('k', 'q', 'score'), eids) g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids) g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
# Send weighted values to target nodes # Send weighted values to target nodes
g.send_and_recv(eids, g.send_and_recv(eids, fn.src_mul_edge('v', 'score', 'v'), fn.sum('v', 'wv'))
[fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')], g.send_and_recv(eids, fn.copy_edge('score', 'score'), fn.sum('score', 'z'))
[fn.sum('v', 'wv'), fn.sum('score', 'z')])
def update_graph(self, g, eids, pre_pairs, post_pairs): def update_graph(self, g, eids, pre_pairs, post_pairs):
"Update the node states and edge states of the graph." "Update the node states and edge states of the graph."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment