Unverified Commit 919b7838 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Example][Bugfix] Fix infograph example (#4298)



* Fix infograph example

* Update

* Revert the changes and update Doc

* Update

* Split lines to pass CI-lint

* Update

* Update
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 14a77c86
"""Classes and functions for batching multiple graphs together."""
from __future__ import absolute_import
from .base import DGLError
from .base import DGLError, dgl_warning
from . import backend as F
from .ops import segment
......@@ -365,8 +365,9 @@ def broadcast_nodes(graph, graph_feat, *, ntype=None):
graph : DGLGraph
The graph.
graph_feat : tensor
The feature to broadcast. Tensor shape is :math:`(*)` for single graph, and
:math:`(B, *)` for batched graph.
The feature to broadcast. Tensor shape is :math:`(B, *)` for batched graph,
where :math:`B` is the batch size.
ntype : str, optional
Node type. Can be omitted if there is only one node type.
......@@ -403,9 +404,11 @@ def broadcast_nodes(graph, graph_feat, *, ntype=None):
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])
Broadcast feature to all nodes in the single graph.
Broadcast feature to all nodes in the single graph (the feature tensor shape
to broadcast should be :math:`(1, *)`).
>>> dgl.broadcast_nodes(g1, feat[0])
>>> feat0 = th.unsqueeze(feat[0], 0)
>>> dgl.broadcast_nodes(g1, feat0)
tensor([[0.4325, 0.7710, 0.5541, 0.0544, 0.9368],
[0.4325, 0.7710, 0.5541, 0.0544, 0.9368]])
......@@ -413,7 +416,9 @@ def broadcast_nodes(graph, graph_feat, *, ntype=None):
--------
broadcast_edges
"""
if len(F.shape(graph_feat)) == 1:
if (F.shape(graph_feat)[0] != graph.batch_size and graph.batch_size == 1):
dgl_warning('For a single graph, use a tensor of shape (1, *) for graph_feat.'
' The support of shape (*) will be deprecated.')
graph_feat = F.unsqueeze(graph_feat, dim=0)
return F.repeat(graph_feat, graph.batch_num_nodes(ntype), dim=0)
......@@ -434,8 +439,8 @@ def broadcast_edges(graph, graph_feat, *, etype=None):
graph : DGLGraph
The graph.
graph_feat : tensor
The feature to broadcast. Tensor shape is :math:`(*)` for single graph, and
:math:`(B, *)` for batched graph.
The feature to broadcast. Tensor shape is :math:`(B, *)` for batched graph,
where :math:`B` is the batch size.
etype : str, typle of str, optional
Edge type. Can be omitted if there is only one edge type in the graph.
......@@ -470,9 +475,11 @@ def broadcast_edges(graph, graph_feat, *, etype=None):
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])
Broadcast feature to all edges in the single graph.
Broadcast feature to all edges in the single graph (the feature tensor shape
to broadcast should be :math:`(1, *)`).
>>> dgl.broadcast_edges(g2, feat[1])
>>> feat1 = th.unsqueeze(feat[1], 0)
>>> dgl.broadcast_edges(g2, feat1)
tensor([[0.2721, 0.4629, 0.7269, 0.0724, 0.1014],
[0.2721, 0.4629, 0.7269, 0.0724, 0.1014]])
......@@ -480,7 +487,9 @@ def broadcast_edges(graph, graph_feat, *, etype=None):
--------
broadcast_nodes
"""
if len(F.shape(graph_feat)) == 1:
if (F.shape(graph_feat)[0] != graph.batch_size and graph.batch_size == 1):
dgl_warning('For a single graph, use a tensor of shape (1, *) for graph_feat.'
' The support of shape (*) will be deprecated.')
graph_feat = F.unsqueeze(graph_feat, dim=0)
return F.repeat(graph_feat, graph.batch_num_edges(etype), dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment