OpenDAS / dgl · Commits · 8dc6784b

Commit 8dc6784b, authored Nov 29, 2018 by Mufei Li; committed by Minjie Wang, Nov 29, 2018

[Doc] docstring for BatchedDGLGraph + some fix (#184)

* Update doc
* Update batched_graph.py
* Fix
* Fix
* Fix
* Fix

Parent: e17c41c0

Changes: 2 changed files, with 416 additions and 66 deletions (+416 -66)

  docs/source/api/python/batch.rst  (+25 -3)
  python/dgl/batched_graph.py       (+391 -63)
docs/source/api/python/batch.rst @ 8dc6784b (+25 -3)

.. _apibatch:

BatchedDGLGraph -- Enable batched graph operations
==================================================

.. currentmodule:: dgl
.. autoclass:: BatchedDGLGraph

Merge and decompose
-------------------

.. autosummary::
    :toctree: ../../generated/

    batch
    unbatch

Query batch summary
-------------------

.. autosummary::
    :toctree: ../../generated/

    BatchedDGLGraph.batch_size
    BatchedDGLGraph.batch_num_nodes
    BatchedDGLGraph.batch_num_edges

Graph Readout
-------------

.. autosummary::
    :toctree: ../../generated/

    sum_nodes
    sum_edges
    mean_nodes
    ...
python/dgl/batched_graph.py @ 8dc6784b (+391 -63)

...
@@ -2,6 +2,7 @@
from __future__ import absolute_import

import numpy as np
from collections import Iterable

from .base import ALL, is_all
from .frame import FrameRef, Frame
...
@@ -16,18 +17,165 @@ __all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split',
class BatchedDGLGraph(DGLGraph):
    """Class for batched DGL graphs.

    A :class:`BatchedDGLGraph` basically merges a list of small graphs into a
    giant graph so that one can perform message passing and readout over a
    batch of graphs simultaneously.

    The nodes and edges are re-indexed with new ids in the batched graph using
    the rule below:

            |  Graph 1   |         Graph 2          |...|        Graph k
    --------------------------------------------------------------------------------
    raw id  | 0, ..., N1 | 0, ..., N2               |...| ..., Nk
    new id  | 0, ..., N1 | N1 + 1, ..., N1 + N2 + 1 |...| ..., N1 + ... + Nk + k - 1

    The batched graph is read-only, i.e. one cannot further add nodes or edges.
    A RuntimeError will be raised if one attempts to do so.

    Modifying the features of a :class:`BatchedDGLGraph` has no effect on the
    original graphs. See the examples below for how to work around this.

    Parameters
    ----------
    graph_list : iterable
        A collection of :class:`~dgl.DGLGraph` objects to be batched.
    node_attrs : None, str or iterable, optional
        The node attributes to be batched. If ``None``, the
        :class:`BatchedDGLGraph` object will not have any node attributes.
        By default, all node attributes will be batched; an error will be
        raised if graphs that have nodes carry different sets of node
        attributes. If ``str`` or ``iterable``, this specifies exactly which
        node attributes to batch.
    edge_attrs : None, str or iterable, optional
        Same as for :attr:`node_attrs`.

    Examples
    --------
    **Instantiation:**

    Create two :class:`~dgl.DGLGraph` objects.

    >>> import dgl
    >>> import torch as th
    >>> g1 = dgl.DGLGraph()
    >>> g1.add_nodes(2)                                # Add 2 nodes
    >>> g1.add_edge(0, 1)                              # Add edge 0 -> 1
    >>> g1.ndata['hv'] = th.tensor([[0.], [1.]])       # Initialize node features
    >>> g1.edata['he'] = th.tensor([[0.]])             # Initialize edge features
    >>> g2 = dgl.DGLGraph()
    >>> g2.add_nodes(3)                                # Add 3 nodes
    >>> g2.add_edges([0, 2], [1, 1])                   # Add edges 0 -> 1, 2 -> 1
    >>> g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features
    >>> g2.edata['he'] = th.tensor([[1.], [2.]])       # Initialize edge features

    Merge the two :class:`~dgl.DGLGraph` objects into one :class:`BatchedDGLGraph`
    object. When merging a list of graphs, we can choose to include only a
    subset of the attributes.

    >>> bg = dgl.batch([g1, g2], edge_attrs=None)
    >>> bg.edata
    {}

    Below one can see that the nodes are re-indexed. The edges are re-indexed
    in the same way.

    >>> bg.nodes()
    tensor([0, 1, 2, 3, 4])
    >>> bg.ndata['hv']
    tensor([[0.],
            [1.],
            [2.],
            [3.],
            [4.]])

    **Properties:**

    We can still get a brief summary of the graphs that constitute the
    batched graph.

    >>> bg.batch_size
    2
    >>> bg.batch_num_nodes
    [2, 3]
    >>> bg.batch_num_edges
    [1, 2]

    **Readout:**

    Another common demand for graph neural networks is graph readout, a
    function that takes in the node attributes and/or edge attributes of a
    graph and outputs a vector summarizing the information in the graph.
    :class:`BatchedDGLGraph` also supports performing readout for a batch of
    graphs at once.

    Below we take the built-in readout function :func:`sum_nodes` as an
    example, which sums a particular node attribute for each graph.

    >>> dgl.sum_nodes(bg, 'hv') # Sum the node attribute 'hv' for each graph.
    tensor([[1.],               # 0 + 1
            [9.]])              # 2 + 3 + 4

    **Message passing:**

    For message passing and related operations, :class:`BatchedDGLGraph` acts
    exactly the same as :class:`~dgl.DGLGraph`.

    **Update attributes:**

    Updating the attributes of the batched graph has no effect on the
    original graphs.

    >>> bg.edata['he'] = th.zeros(3, 2)
    >>> g2.edata['he']
    tensor([[1.],
            [2.]])

    Instead, we can decompose the batched graph back into a list of graphs
    and use them to replace the original graphs.

    >>> g1, g2 = dgl.unbatch(bg) # returns a list of DGLGraphs
    >>> g2.edata['he']
    tensor([[0., 0.],
            [0., 0.]])
    """
    def __init__(self, graph_list, node_attrs, edge_attrs):
        def _get_num_item_and_attr_types(g, mode):
            if mode == 'node':
                num_items = g.number_of_nodes()
                attr_types = set(g.node_attr_schemes().keys())
            elif mode == 'edge':
                num_items = g.number_of_edges()
                attr_types = set(g.edge_attr_schemes().keys())
            return num_items, attr_types

        def _init_attrs(attrs, mode):
            if attrs is None:
                return []
            elif is_all(attrs):
                attrs = set()
                # Check if at least one graph has mode items and associated features.
                for i in range(len(graph_list)):
                    g = graph_list[i]
                    g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
                    if g_num_items > 0 and len(g_attrs) > 0:
                        attrs = g_attrs
                        ref_g_index = i
                        break
                # Check if all the graphs with mode items have the same associated features.
                if len(attrs) > 0:
                    for i in range(len(graph_list)):
                        g = graph_list[i]
                        g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
                        if g_attrs != attrs and g_num_items > 0:
                            raise ValueError('Expect graph {} and {} to have the same {} '
                                             'attributes when {}_attrs=ALL, got {} and '
                                             '{}'.format(ref_g_index, i, mode, mode,
                                                         attrs, g_attrs))
                return attrs
            elif isinstance(attrs, str):
                return [attrs]
            elif isinstance(attrs, Iterable):
                return attrs
            else:
                raise ValueError('Expected {} attrs to be of type None, str or Iterable, '
                                 'got type {}'.format(mode, type(attrs)))

        node_attrs = _init_attrs(node_attrs, 'node')
        edge_attrs = _init_attrs(edge_attrs, 'edge')

        # create batched graph index
        batched_index = gi.disjoint_union([g._graph for g in graph_list])
        # create batched node and edge frames
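As an aside (editor's illustration, not part of the commit): the disjoint-union re-indexing rule described in the class docstring amounts to shifting each graph's raw node ids by the cumulative node counts of the preceding graphs. A minimal sketch in plain Python, with a hypothetical helper name:

```python
from itertools import accumulate

def batched_node_ids(num_nodes_per_graph):
    """For each graph, map raw node ids to their ids in the batched graph."""
    # Offset of graph i = total number of nodes in graphs 0..i-1.
    offsets = [0] + list(accumulate(num_nodes_per_graph))[:-1]
    return [[off + raw for raw in range(n)]
            for n, off in zip(num_nodes_per_graph, offsets)]

# Two graphs with 2 and 3 nodes, as in the docstring example above.
print(batched_node_ids([2, 3]))  # [[0, 1], [2, 3, 4]]
```

The edges of the batched graph are shifted the same way, which is why `dgl.unbatch` can later recover the originals from the per-graph node/edge counts.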
...
@@ -70,17 +218,32 @@ class BatchedDGLGraph(DGLGraph):
    @property
    def batch_size(self):
        """Number of graphs in this batch.

        Returns
        -------
        int
            Number of graphs in this batch.
        """
        return self._batch_size

    @property
    def batch_num_nodes(self):
        """Number of nodes of each graph in this batch.

        Returns
        -------
        list
            Number of nodes of each graph in this batch.
        """
        return self._batch_num_nodes

    @property
    def batch_num_edges(self):
        """Number of edges of each graph in this batch.

        Returns
        -------
        list
            Number of edges of each graph in this batch.
        """
        return self._batch_num_edges

    # override APIs
...
@@ -113,20 +276,31 @@ def split(graph_batch, num_or_size_splits):
    pass


def unbatch(graph):
    """Return the list of graphs in this batch.

    Parameters
    ----------
    graph : BatchedDGLGraph
        The batched graph.

    Returns
    -------
    list
        A list of :class:`~dgl.DGLGraph` objects whose attributes are obtained
        by partitioning the attributes of :attr:`graph`. The length of the
        list equals the batch size of :attr:`graph`.

    Notes
    -----
    Unbatching will break each field tensor of the batched graph into smaller
    partitions.

    For simpler tasks such as node/edge state aggregation, try to use
    readout functions instead.

    See Also
    --------
    batch
    """
    assert isinstance(graph, BatchedDGLGraph)
    bsize = graph.batch_size
...
@@ -149,37 +323,31 @@ def unbatch(graph):
                      edge_frame=edge_frames[i]) for i in range(bsize)]
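To make the note above concrete (editor's sketch, not DGL's actual implementation): unbatching essentially splits each stacked feature tensor at the per-graph boundaries given by `batch_num_nodes` / `batch_num_edges`. A rough NumPy illustration:

```python
import numpy as np

def split_features(stacked, batch_num_nodes):
    """Split a batched feature array back into per-graph arrays."""
    # Split points are the cumulative node counts, excluding the final total.
    boundaries = np.cumsum(batch_num_nodes)[:-1]
    return np.split(stacked, boundaries)

# The batched 'hv' feature from the BatchedDGLGraph docstring example.
hv = np.array([[0.], [1.], [2.], [3.], [4.]])
parts = split_features(hv, [2, 3])
print([p.tolist() for p in parts])  # [[[0.0], [1.0]], [[2.0], [3.0], [4.0]]]
```

This per-field copying is why the docstring recommends readout functions over `unbatch` when only an aggregate per graph is needed.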
def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
    """Batch a collection of :class:`~dgl.DGLGraph` objects and return a
    :class:`BatchedDGLGraph` object that is independent of :attr:`graph_list`.

    Once batch is called, the structure of both the merged graph and the
    graphs in :attr:`graph_list` must not be mutated, or unbatch's behavior
    will be undefined.

    Parameters
    ----------
    graph_list : iterable
        A collection of :class:`~dgl.DGLGraph` objects to be batched.
    node_attrs : None, str or iterable, optional
        The node attributes to be batched. If ``None``, the
        :class:`BatchedDGLGraph` object will not have any node attributes.
        By default, all node attributes will be batched. If ``str`` or
        ``iterable``, this specifies exactly which node attributes to batch.
    edge_attrs : None, str or iterable, optional
        Same as for :attr:`node_attrs`.

    Returns
    -------
    BatchedDGLGraph
        One single batched graph.

    See Also
    --------
    BatchedDGLGraph
    unbatch
    """
    return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
...
@@ -212,8 +380,8 @@ def _sum_on(graph, on, input, weight):
        return F.sum(input, 0)
def sum_nodes(graph, input, weight=None):
    """Sums all the values of node field :attr:`input` in :attr:`graph`,
    optionally multiplies the field by a scalar node field :attr:`weight`.

    Parameters
    ----------
...
@@ -221,8 +389,11 @@ def sum_nodes(graph, input, weight=None):
        The graph
    input : str
        The input field
    weight : str, optional
        The weight field. If None, no weighting will be performed;
        otherwise, weight each node feature with field :attr:`weight`
        for summation. The weight feature in :attr:`graph` should be
        a tensor of shape ``[graph.number_of_nodes(), 1]``.

    Returns
    -------
...
@@ -231,17 +402,53 @@ def sum_nodes(graph, input, weight=None):
    Notes
    -----
    If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
    returned instead, i.e. having an extra first dimension.
    Each row of the stacked tensor contains the readout result of the
    corresponding example in the batch. If an example has no nodes,
    a zero tensor with the same shape is returned at the corresponding row.

    Examples
    --------
    >>> import dgl
    >>> import torch as th

    Create two :class:`~dgl.DGLGraph` objects and initialize their
    node features.

    >>> g1 = dgl.DGLGraph()                           # Graph 1
    >>> g1.add_nodes(2)
    >>> g1.ndata['h'] = th.tensor([[1.], [2.]])
    >>> g1.ndata['w'] = th.tensor([[3.], [6.]])
    >>> g2 = dgl.DGLGraph()                           # Graph 2
    >>> g2.add_nodes(3)
    >>> g2.ndata['h'] = th.tensor([[1.], [2.], [3.]])

    Sum over node attribute 'h' without weighting for each graph in a
    batched graph.

    >>> bg = dgl.batch([g1, g2], node_attrs='h')
    >>> dgl.sum_nodes(bg, 'h')
    tensor([[3.],  # 1 + 2
            [6.]]) # 1 + 2 + 3

    Sum node attribute 'h' with weight from node attribute 'w' for a
    single graph.

    >>> dgl.sum_nodes(g1, 'h', 'w')
    tensor([15.]) # 1 * 3 + 2 * 6

    See Also
    --------
    mean_nodes
    sum_edges
    mean_edges
    """
    return _sum_on(graph, 'nodes', input, weight)
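Viewed outside DGL (editor's sketch, hypothetical helper name): for a `BatchedDGLGraph`, `sum_nodes` behaves like a segment-sum over the stacked node features, with a zero row for any empty graph, as the Notes above describe. A minimal NumPy version of that readout:

```python
import numpy as np

def segment_sum(feats, batch_num_nodes):
    """Per-graph sums of stacked node features; zero rows for empty graphs."""
    out = np.zeros((len(batch_num_nodes), feats.shape[1]))
    start = 0
    for i, n in enumerate(batch_num_nodes):
        if n > 0:
            out[i] = feats[start:start + n].sum(axis=0)
        start += n
    return out

# Stacked 'hv' features of the two example graphs (2 and 3 nodes).
hv = np.array([[0.], [1.], [2.], [3.], [4.]])
print(segment_sum(hv, [2, 3]).tolist())  # [[1.0], [9.0]]
```

The stacked `[batch_size, feature_dim]` result is what makes per-graph losses easy to compute after message passing.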
def sum_edges(graph, input, weight=None):
    """Sums all the values of edge field :attr:`input` in :attr:`graph`,
    optionally multiplies the field by a scalar edge field :attr:`weight`.

    Parameters
    ----------
...
@@ -249,8 +456,11 @@ def sum_edges(graph, input, weight=None):
        The graph
    input : str
        The input field
    weight : str, optional
        The weight field. If None, no weighting will be performed;
        otherwise, weight each edge feature with field :attr:`weight`
        for summation. The weight feature in :attr:`graph` should be
        a tensor of shape ``[graph.number_of_edges(), 1]``.

    Returns
    -------
...
@@ -259,11 +469,49 @@ def sum_edges(graph, input, weight=None):
    Notes
    -----
    If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
    returned instead, i.e. having an extra first dimension.
    Each row of the stacked tensor contains the readout result of the
    corresponding example in the batch. If an example has no edges,
    a zero tensor with the same shape is returned at the corresponding row.

    Examples
    --------
    >>> import dgl
    >>> import torch as th

    Create two :class:`~dgl.DGLGraph` objects and initialize their
    edge features.

    >>> g1 = dgl.DGLGraph()                           # Graph 1
    >>> g1.add_nodes(2)
    >>> g1.add_edges([0, 1], [1, 0])
    >>> g1.edata['h'] = th.tensor([[1.], [2.]])
    >>> g1.edata['w'] = th.tensor([[3.], [6.]])
    >>> g2 = dgl.DGLGraph()                           # Graph 2
    >>> g2.add_nodes(3)
    >>> g2.add_edges([0, 1, 2], [1, 2, 0])
    >>> g2.edata['h'] = th.tensor([[1.], [2.], [3.]])

    Sum over edge attribute 'h' without weighting for each graph in a
    batched graph.

    >>> bg = dgl.batch([g1, g2], edge_attrs='h')
    >>> dgl.sum_edges(bg, 'h')
    tensor([[3.],  # 1 + 2
            [6.]]) # 1 + 2 + 3

    Sum edge attribute 'h' with weight from edge attribute 'w' for a
    single graph.

    >>> dgl.sum_edges(g1, 'h', 'w')
    tensor([15.]) # 1 * 3 + 2 * 6

    See Also
    --------
    sum_nodes
    mean_nodes
    mean_edges
    """
    return _sum_on(graph, 'edges', input, weight)
...
@@ -301,8 +549,8 @@ def _mean_on(graph, on, input, weight):
        return y
def mean_nodes(graph, input, weight=None):
    """Averages all the values of node field :attr:`input` in :attr:`graph`,
    optionally multiplies the field by a scalar node field :attr:`weight`.

    Parameters
    ----------
...
@@ -310,8 +558,11 @@ def mean_nodes(graph, input, weight=None):
        The graph
    input : str
        The input field
    weight : str, optional
        The weight field. If None, no weighting will be performed;
        otherwise, weight each node feature with field :attr:`weight`
        for calculating the mean. The weight feature in :attr:`graph`
        should be a tensor of shape ``[graph.number_of_nodes(), 1]``.

    Returns
    -------
...
@@ -320,17 +571,53 @@ def mean_nodes(graph, input, weight=None):
    Notes
    -----
    If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
    returned instead, i.e. having an extra first dimension.
    Each row of the stacked tensor contains the readout result of the
    corresponding example in the batch. If an example has no nodes,
    a zero tensor with the same shape is returned at the corresponding row.

    Examples
    --------
    >>> import dgl
    >>> import torch as th

    Create two :class:`~dgl.DGLGraph` objects and initialize their
    node features.

    >>> g1 = dgl.DGLGraph()                           # Graph 1
    >>> g1.add_nodes(2)
    >>> g1.ndata['h'] = th.tensor([[1.], [2.]])
    >>> g1.ndata['w'] = th.tensor([[3.], [6.]])
    >>> g2 = dgl.DGLGraph()                           # Graph 2
    >>> g2.add_nodes(3)
    >>> g2.ndata['h'] = th.tensor([[1.], [2.], [3.]])

    Average over node attribute 'h' without weighting for each graph in a
    batched graph.

    >>> bg = dgl.batch([g1, g2], node_attrs='h')
    >>> dgl.mean_nodes(bg, 'h')
    tensor([[1.5000],  # (1 + 2) / 2
            [2.0000]]) # (1 + 2 + 3) / 3

    Sum node attribute 'h' with normalized weight from node attribute 'w'
    for a single graph.

    >>> dgl.mean_nodes(g1, 'h', 'w') # h1 * (w1 / (w1 + w2)) + h2 * (w2 / (w1 + w2))
    tensor([1.6667]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6))

    See Also
    --------
    sum_nodes
    sum_edges
    mean_edges
    """
    return _mean_on(graph, 'nodes', input, weight)
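The weighted mean in the docstring example above normalizes the weights within each graph, i.e. each feature contributes with coefficient w_i / sum(w). A short sketch for a single graph (editor's NumPy illustration, not the library code):

```python
import numpy as np

def weighted_mean(h, w):
    """Weighted mean of features h with scalar per-node weights w (shape [N, 1])."""
    # Equivalent to sum_i h_i * (w_i / sum(w)).
    return (h * w).sum(axis=0) / w.sum()

# Node features and weights of g1 from the mean_nodes example.
h = np.array([[1.], [2.]])
w = np.array([[3.], [6.]])
print(weighted_mean(h, w))  # ~1.667, i.e. 1 * (3/9) + 2 * (6/9)
```

The unweighted case reduces to the arithmetic mean, matching the `tensor([[1.5000], [2.0000]])` output shown for the batched graph.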
def mean_edges(graph, input, weight=None):
    """Averages all the values of edge field :attr:`input` in :attr:`graph`,
    optionally multiplies the field by a scalar edge field :attr:`weight`.

    Parameters
    ----------
...
@@ -339,7 +626,10 @@ def mean_edges(graph, input, weight=None):
    input : str
        The input field
    weight : str, optional
        The weight field. If None, no weighting will be performed;
        otherwise, weight each edge feature with field :attr:`weight`
        for calculating the mean. The weight feature in :attr:`graph`
        should be a tensor of shape ``[graph.number_of_edges(), 1]``.

    Returns
    -------
...
@@ -348,10 +638,48 @@ def mean_edges(graph, input, weight=None):
    Notes
    -----
    If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
    returned instead, i.e. having an extra first dimension.
    Each row of the stacked tensor contains the readout result of the
    corresponding example in the batch. If an example has no edges,
    a zero tensor with the same shape is returned at the corresponding row.

    Examples
    --------
    >>> import dgl
    >>> import torch as th

    Create two :class:`~dgl.DGLGraph` objects and initialize their
    edge features.

    >>> g1 = dgl.DGLGraph()                           # Graph 1
    >>> g1.add_nodes(2)
    >>> g1.add_edges([0, 1], [1, 0])
    >>> g1.edata['h'] = th.tensor([[1.], [2.]])
    >>> g1.edata['w'] = th.tensor([[3.], [6.]])
    >>> g2 = dgl.DGLGraph()                           # Graph 2
    >>> g2.add_nodes(3)
    >>> g2.add_edges([0, 1, 2], [1, 2, 0])
    >>> g2.edata['h'] = th.tensor([[1.], [2.], [3.]])

    Average over edge attribute 'h' without weighting for each graph in a
    batched graph.

    >>> bg = dgl.batch([g1, g2], edge_attrs='h')
    >>> dgl.mean_edges(bg, 'h')
    tensor([[1.5000],  # (1 + 2) / 2
            [2.0000]]) # (1 + 2 + 3) / 3

    Sum edge attribute 'h' with normalized weight from edge attribute 'w'
    for a single graph.

    >>> dgl.mean_edges(g1, 'h', 'w') # h1 * (w1 / (w1 + w2)) + h2 * (w2 / (w1 + w2))
    tensor([1.6667]) # 1 * (3 / (3 + 6)) + 2 * (6 / (3 + 6))

    See Also
    --------
    sum_nodes
    mean_nodes
    sum_edges
    """
    return _mean_on(graph, 'edges', input, weight)