"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "39e1f7eaa4bf414f822ed13922a89df89aabe0bf"
Unverified Commit 8dbf3a47 authored by Krzysztof Sadowski's avatar Krzysztof Sadowski Committed by GitHub
Browse files

[NN] Fix GATv2Conv residual for mini-batch (#3535)

parent ea8b5d79
...@@ -287,6 +287,7 @@ class GATv2Conv(nn.Module): ...@@ -287,6 +287,7 @@ class GATv2Conv(nn.Module):
-1, self._num_heads, self._out_feats) -1, self._num_heads, self._out_feats)
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
graph.srcdata.update({'el': feat_src})# (num_src_edge, num_heads, out_dim) graph.srcdata.update({'el': feat_src})# (num_src_edge, num_heads, out_dim)
graph.dstdata.update({'er': feat_dst}) graph.dstdata.update({'er': feat_dst})
graph.apply_edges(fn.u_add_v('el', 'er', 'e')) graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment