Unverified Commit bc5cac44 authored by jwyyy's avatar jwyyy Committed by GitHub
Browse files

fix broadcast tensor dim (#3351)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent bacc9047
...@@ -413,6 +413,8 @@ def broadcast_nodes(graph, graph_feat, *, ntype=None): ...@@ -413,6 +413,8 @@ def broadcast_nodes(graph, graph_feat, *, ntype=None):
-------- --------
broadcast_edges broadcast_edges
""" """
if len(F.shape(graph_feat)) == 1:
graph_feat = F.unsqueeze(graph_feat, dim=0)
return F.repeat(graph_feat, graph.batch_num_nodes(ntype), dim=0) return F.repeat(graph_feat, graph.batch_num_nodes(ntype), dim=0)
def broadcast_edges(graph, graph_feat, *, etype=None): def broadcast_edges(graph, graph_feat, *, etype=None):
...@@ -478,6 +480,8 @@ def broadcast_edges(graph, graph_feat, *, etype=None): ...@@ -478,6 +480,8 @@ def broadcast_edges(graph, graph_feat, *, etype=None):
-------- --------
broadcast_nodes broadcast_nodes
""" """
if len(F.shape(graph_feat)) == 1:
graph_feat = F.unsqueeze(graph_feat, dim=0)
return F.repeat(graph_feat, graph.batch_num_edges(etype), dim=0) return F.repeat(graph_feat, graph.batch_num_edges(etype), dim=0)
READOUT_ON_ATTRS = { READOUT_ON_ATTRS = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment