"vscode:/vscode.git/clone" did not exist on "786ec3254d54b00fd40c6a517c8f22b1642bbfc7"
Unverified Commit bc5cac44 authored by jwyyy's avatar jwyyy Committed by GitHub
Browse files

fix broadcast tensor dim (#3351)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent bacc9047
......@@ -413,6 +413,8 @@ def broadcast_nodes(graph, graph_feat, *, ntype=None):
--------
broadcast_edges
"""
if len(F.shape(graph_feat)) == 1:
graph_feat = F.unsqueeze(graph_feat, dim=0)
return F.repeat(graph_feat, graph.batch_num_nodes(ntype), dim=0)
def broadcast_edges(graph, graph_feat, *, etype=None):
......@@ -478,6 +480,8 @@ def broadcast_edges(graph, graph_feat, *, etype=None):
--------
broadcast_nodes
"""
if len(F.shape(graph_feat)) == 1:
graph_feat = F.unsqueeze(graph_feat, dim=0)
return F.repeat(graph_feat, graph.batch_num_edges(etype), dim=0)
READOUT_ON_ATTRS = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment