Commit 40ca5de4 authored by Gan Quan, committed by Minjie Wang

[API] Readout interfaces (#124)

* fixing builtin src*edge shape mismatch

* bundled function refactor (?)

* fixing names

* readout prototype

* oops

* more fixes

* removing readout prototype

* sum_on() with SPMV, fixing batching with 0 edges

* readouts with segmented sum

* typo (??????)

* fixes NLTK dependency (#125)

* misc fixes including #126 (pushing again)

* sanity check for mxnet

* fixes NLTK dependency (#125) and #126

* reverting to sum_nodes/edges
parent 8ea359d1
@@ -103,8 +103,6 @@ venv.bak/
 # mypy
 .mypy_cache/
-*.swp
-*.swo
 examples/pytorch/data/ind.pubmed.y
 examples/pytorch/data/ind.pubmed.x
 examples/pytorch/data/ind.pubmed.ty
@@ -136,3 +134,13 @@ examples/pytorch/generative_graph/*.p
 # data directory
 _download
+
+# CTags & CScope
+tags
+cscope.*
+
+# Vim
+*.swp
+*.swo
+*.un~
+*~
@@ -162,6 +162,21 @@ def dtype(input):
     """
     pass

+def ndim(input):
+    """Return the number of dimensions of the tensor.
+
+    Parameters
+    ----------
+    input : Tensor
+        The input tensor.
+
+    Returns
+    -------
+    int
+        The number of dimensions.
+    """
+    pass
+
 def context(input):
     """Return the context/device of the input tensor.
@@ -251,6 +266,23 @@ def sum(input, dim):
     """
     pass

+def mean(input, dim):
+    """Average the input tensor along the given dim.
+
+    Parameters
+    ----------
+    input : Tensor
+        The input tensor.
+    dim : int
+        The reduce dim.
+
+    Returns
+    -------
+    Tensor
+        A framework-specific tensor.
+    """
+    pass
+
 def max(input, dim):
     """Reduce max the input tensor along the given dim.
@@ -285,6 +317,23 @@ def cat(seq, dim):
     """
     pass

+def stack(seq, dim):
+    """Stack the sequence of tensors along the given dimension.
+
+    Parameters
+    ----------
+    seq : list of Tensor
+        The tensor sequence.
+    dim : int
+        The stacking dim.
+
+    Returns
+    -------
+    Tensor
+        A framework-specific tensor.
+    """
+    pass
+
 def split(input, sizes_or_sections, dim):
     """Split the input tensor into chunks.
@@ -484,6 +533,58 @@ def spmm(x, y):
     """
     pass

+def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
+    """Computes the sum along segments of a tensor.
+
+    Equivalent to tf.unsorted_segment_sum, except that seg_id is required
+    to be a 1D tensor.
+
+    Parameters
+    ----------
+    input : Tensor
+        The input tensor.
+    seg_id : 1D Tensor
+        The segment IDs, with values between 0 and n_segs - 1. Should
+        have the same length as input.
+    n_segs : int
+        Number of distinct segments.
+    dim : int
+        Dimension to sum on.
+
+    Returns
+    -------
+    Tensor
+        The result.
+    """
+    pass
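A concrete example makes the semantics clear. The following is a minimal sketch (PyTorch assumed, mirroring the backend implementation further down) of three rows summed into two segments:

import torch as th

x = th.tensor([[1., 2.], [3., 4.], [5., 6.]])
seg_id = th.tensor([0, 1, 0])
n_segs = 2

y = th.zeros(n_segs, x.shape[1])
y.scatter_add_(0, seg_id.view(-1, 1).expand_as(x), x)
# rows 0 and 2 fall into segment 0, row 1 into segment 1:
# y == [[6., 8.], [3., 4.]]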
+def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
+    """Computes the mean along segments of a tensor.
+
+    Equivalent to tf.unsorted_segment_mean, except that seg_id is required
+    to be a 1D tensor.
+
+    Note that segments that never appear in seg_id will have a result of 0.
+
+    Parameters
+    ----------
+    input : Tensor
+        The input tensor.
+    seg_id : 1D Tensor
+        The segment IDs, with values between 0 and n_segs - 1. Should
+        have the same length as input.
+    n_segs : int
+        Number of distinct segments.
+    dim : int
+        Dimension to average on.
+
+    Returns
+    -------
+    Tensor
+        The result.
+    """
+    pass
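The zero-for-empty-segments behavior is worth a sketch of its own (again PyTorch, matching the clamp trick used by the backend implementations below):

import torch as th

x = th.tensor([[2., 4.], [6., 8.]])
seg_id = th.tensor([0, 0])
n_segs = 2  # segment 1 receives no rows

sums = th.zeros(n_segs, x.shape[1])
sums.scatter_add_(0, seg_id.view(-1, 1).expand_as(x), x)
counts = th.zeros(n_segs).scatter_add_(0, seg_id, th.ones(2))
mean = sums / counts.clamp(min=1).view(-1, 1)
# mean == [[4., 6.], [0., 0.]]; the empty segment yields 0, not NaN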
 ###############################################################################
 # Tensor functions used *only* on index tensor
 # ----------------
...
@@ -49,6 +49,9 @@ def dtype(input):
     # NOTE: the input cannot be a symbol
     return input.dtype

+def ndim(input):
+    return input.ndim
+
 def context(input):
     return input.context
@@ -64,12 +67,18 @@ def copy_to(input, ctx):
 def sum(input, dim):
     return nd.sum(input, axis=dim)

+def mean(input, dim):
+    return nd.mean(input, axis=dim)
+
 def max(input, dim):
     return nd.max(input, axis=dim).asnumpy()[0]

 def cat(seq, dim):
     return nd.concat(*seq, dim=dim)

+def stack(seq, dim):
+    # NOTE: unlike nd.concat, nd.stack takes `axis` rather than `dim`
+    return nd.stack(*seq, axis=dim)
+
 def split(x, sizes_or_sections, dim):
     if isinstance(sizes_or_sections, list):
         # TODO: fallback to numpy is unfortunate
@@ -114,6 +123,35 @@ def ones(shape, dtype):
 def spmm(x, y):
     return nd.dot(x, y)

+def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
+    # TODO: support other dimensions
+    assert dim == 0, 'MXNet only supports segment sum on first dimension'
+
+    # Use SPMV to simulate segment sum: build a (n_segs, n_inputs) one-hot
+    # selector matrix in CSR form and multiply it with the flattened input.
+    n_inputs = input.shape[0]
+    input_shape_suffix = input.shape[1:]
+    input = input.reshape(n_inputs, -1)
+    n_range = nd.arange(n_inputs, dtype='int64').as_in_context(input.context)
+    w_nnz = nd.ones(n_inputs).as_in_context(input.context)
+    w = nd.sparse.csr_matrix((w_nnz, (seg_id, n_range)), (n_segs, n_inputs))
+    w = w.as_in_context(input.context)
+    y = nd.dot(w, input)
+    y = nd.reshape(y, (n_segs,) + input_shape_suffix)
+    return y
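The SPMV trick above is just multiplication by a one-hot selector matrix. A dense numpy sketch of the same computation (the real code uses a CSR matrix):

import numpy as np

x = np.array([[1., 2.], [3., 4.], [5., 6.]])
seg_id = np.array([0, 1, 0])
n_segs, n_inputs = 2, 3

w = np.zeros((n_segs, n_inputs))
w[seg_id, np.arange(n_inputs)] = 1.0  # W[s, i] = 1 iff row i is in segment s
y = w @ x  # [[6., 8.], [3., 4.]], same result as the sparse nd.dot above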
+def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
+    # TODO: support other dimensions
+    assert dim == 0, 'MXNet only supports segment mean on first dimension'
+
+    # Segment mean = segment sum / segment size; clip the sizes at 1 so
+    # that empty segments yield 0 instead of NaN.
+    n_ones = nd.ones_like(seg_id).astype(input.dtype)
+    w = unsorted_1d_segment_sum(n_ones, seg_id, n_segs, 0)
+    w = nd.clip(w, a_min=1, a_max=np.inf)
+    y = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
+    y /= w.reshape((-1,) + (1,) * (y.ndim - 1))
+    return y
 def unique(input):
     # TODO: fallback to numpy is unfortunate
     tmp = input.asnumpy()
...
@@ -37,6 +37,9 @@ def shape(input):
 def dtype(input):
     return input.dtype

+def ndim(input):
+    return input.dim()
+
 def context(input):
     return input.device
@@ -58,6 +61,9 @@ def copy_to(input, ctx):
 def sum(input, dim):
     return th.sum(input, dim=dim)

+def mean(input, dim):
+    return th.mean(input, dim=dim)
+
 def max(input, dim):
     # NOTE: the second argmax array is not returned
     return th.max(input, dim=dim)[0]
@@ -65,6 +71,9 @@ def max(input, dim):
 def cat(seq, dim):
     return th.cat(seq, dim=dim)

+def stack(seq, dim):
+    return th.stack(seq, dim=dim)
+
 def split(input, sizes_or_sections, dim):
     return th.split(input, sizes_or_sections, dim)
@@ -98,6 +107,19 @@ def ones(shape, dtype):
 def spmm(x, y):
     return th.spmm(x, y)

+def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
+    # Scatter-add each row of `input` into the output row given by its
+    # segment ID; `seg_id` is broadcast to the shape of `input` first.
+    y = th.zeros(n_segs, *input.shape[1:]).to(input)
+    seg_id = seg_id.view((-1,) + (1,) * (input.dim() - 1)).expand_as(input)
+    y = y.scatter_add_(dim, seg_id, input)
+    return y
+
+def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
+    # Segment mean = segment sum / segment size, with sizes clamped at 1
+    # so that empty segments yield 0 instead of NaN.
+    w = unsorted_1d_segment_sum(th.ones_like(seg_id), seg_id, n_segs, 0).to(input)
+    w = w.clamp(min=1)  # remove 0 entries
+    y = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
+    y /= w.view((-1,) + (1,) * (y.dim() - 1))
+    return y
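As a quick sanity check of the two helpers (assuming a PyTorch build where they are re-exported through dgl.backend):

import torch as th
import dgl.backend as F

x = th.arange(6, dtype=th.float).view(3, 2)  # [[0, 1], [2, 3], [4, 5]]
seg = th.tensor([1, 0, 1])
F.unsorted_1d_segment_sum(x, seg, 2, 0)   # [[2., 3.], [4., 6.]]
F.unsorted_1d_segment_mean(x, seg, 2, 0)  # [[2., 3.], [2., 3.]]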
 def unique(input):
     return th.unique(input)
...
@@ -10,7 +10,8 @@ from . import graph_index as gi
 from . import backend as F
 from . import utils

-__all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split']
+__all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split',
+           'sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges']

 class BatchedDGLGraph(DGLGraph):
     """Class for batched DGL graphs.
@@ -31,11 +32,13 @@ class BatchedDGLGraph(DGLGraph):
         batched_index = gi.disjoint_union([g._graph for g in graph_list])
         # create batched node and edge frames
         # NOTE: following code will materialize the columns of the input graphs.
-        cols = {key: F.cat([gr._node_frame[key] for gr in graph_list], dim=0)
+        cols = {key: F.cat([gr._node_frame[key] for gr in graph_list
+                            if gr.number_of_nodes() > 0], dim=0)
                 for key in node_attrs}
         batched_node_frame = FrameRef(Frame(cols))
-        cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list], dim=0)
+        cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list
+                            if gr.number_of_edges() > 0], dim=0)
                 for key in edge_attrs}
         batched_edge_frame = FrameRef(Frame(cols))
@@ -98,25 +101,6 @@ class BatchedDGLGraph(DGLGraph):
         # TODO
         pass

-    def readout(self, reduce_func):
-        """Perform readout for each graph in the batch.
-
-        The readout value is a tensor of shape (B, D1, D2, ...) where B is the
-        batch size.
-
-        Parameters
-        ----------
-        reduce_func : callable
-            The reduce function for readout.
-
-        Returns
-        -------
-        dict of tensors
-            The readout values.
-        """
-        # TODO
-        pass
 def split(graph_batch, num_or_size_splits):
     """Split the batch."""
     # TODO(minjie): could follow torch.split syntax
@@ -136,7 +120,7 @@ def unbatch(graph):
     smaller partitions. This is usually wasteful.

     For simpler tasks such as node/edge state aggregation by example,
-    try to use BatchedDGLGraph.readout().
+    try to use readout functions.
     """
     assert isinstance(graph, BatchedDGLGraph)
     bsize = graph.batch_size
@@ -191,3 +175,177 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
     elif isinstance(edge_attrs, str):
         edge_attrs = [edge_attrs]
     return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)

+_readout_on_attrs = {
+    'nodes': ('ndata', 'batch_num_nodes', 'number_of_nodes'),
+    'edges': ('edata', 'batch_num_edges', 'number_of_edges'),
+}
+
+def _sum_on(graph, on, input, weight):
+    data_attr, batch_num_objs_attr, num_objs_attr = _readout_on_attrs[on]
+    data = getattr(graph, data_attr)
+    input = data[input]
+    if weight is not None:
+        # reshape the scalar weight so it broadcasts against the input field
+        weight = data[weight]
+        weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(input) - 1))
+        input = weight * input
+
+    if isinstance(graph, BatchedDGLGraph):
+        n_graphs = graph.batch_size
+        batch_num_objs = getattr(graph, batch_num_objs_attr)
+        n_objs = getattr(graph, num_objs_attr)()
+        # one segment per graph in the batch
+        seg_id = F.zerocopy_from_numpy(
+            np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
+        seg_id = F.copy_to(seg_id, F.context(input))
+        y = F.unsorted_1d_segment_sum(input, seg_id, n_graphs, 0)
+        return y
+    else:
+        return F.sum(input, 0)
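The segment IDs assign one segment per graph; for example, for a batch whose graphs have 3 and 4 nodes:

import numpy as np

batch_num_objs = [3, 4]
seg_id = np.arange(2, dtype='int64').repeat(batch_num_objs)
# seg_id == [0, 0, 0, 1, 1, 1, 1]; nodes of graph i all map to segment i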
+def sum_nodes(graph, input, weight=None):
+    """Sum all the values of node field `input` in `graph`, optionally
+    multiplying each value by the scalar node field `weight`.
+
+    Parameters
+    ----------
+    graph : DGLGraph or BatchedDGLGraph
+        The graph.
+    input : str
+        The input field.
+    weight : str, optional
+        The weight field. Default: None (no weighting).
+
+    Returns
+    -------
+    tensor
+        The summed tensor.
+
+    Notes
+    -----
+    If graph is a BatchedDGLGraph, a stacked tensor is returned instead,
+    i.e. having an extra first dimension.
+    Each row of the stacked tensor contains the readout result of the
+    corresponding example in the batch. If an example has no nodes,
+    its row is a zero tensor of the same shape.
+    """
+    return _sum_on(graph, 'nodes', input, weight)
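A minimal usage sketch (the field names 'h' and 'w' are hypothetical):

import torch as th
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.ndata['h'] = th.tensor([[1.], [2.], [3.]])
g.ndata['w'] = th.tensor([2., 0., 1.])
dgl.sum_nodes(g, 'h')       # tensor([6.])
dgl.sum_nodes(g, 'h', 'w')  # 2*1 + 0*2 + 1*3 = tensor([5.])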
+def sum_edges(graph, input, weight=None):
+    """Sum all the values of edge field `input` in `graph`, optionally
+    multiplying each value by the scalar edge field `weight`.
+
+    Parameters
+    ----------
+    graph : DGLGraph or BatchedDGLGraph
+        The graph.
+    input : str
+        The input field.
+    weight : str, optional
+        The weight field. Default: None (no weighting).
+
+    Returns
+    -------
+    tensor
+        The summed tensor.
+
+    Notes
+    -----
+    If graph is a BatchedDGLGraph, a stacked tensor is returned instead,
+    i.e. having an extra first dimension.
+    Each row of the stacked tensor contains the readout result of the
+    corresponding example in the batch. If an example has no edges,
+    its row is a zero tensor of the same shape.
+    """
+    return _sum_on(graph, 'edges', input, weight)
+def _mean_on(graph, on, input, weight):
+    data_attr, batch_num_objs_attr, num_objs_attr = _readout_on_attrs[on]
+    data = getattr(graph, data_attr)
+    input = data[input]
+    if weight is not None:
+        # reshape the scalar weight so it broadcasts against the input field
+        weight = data[weight]
+        weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(input) - 1))
+        input = weight * input
+
+    if isinstance(graph, BatchedDGLGraph):
+        n_graphs = graph.batch_size
+        batch_num_objs = getattr(graph, batch_num_objs_attr)
+        n_objs = getattr(graph, num_objs_attr)()
+        # one segment per graph in the batch
+        seg_id = F.zerocopy_from_numpy(
+            np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
+        seg_id = F.copy_to(seg_id, F.context(input))
+        if weight is not None:
+            # weighted mean: a ratio of two segment sums
+            w = F.unsorted_1d_segment_sum(weight, seg_id, n_graphs, 0)
+            y = F.unsorted_1d_segment_sum(input, seg_id, n_graphs, 0)
+            y = y / w
+        else:
+            y = F.unsorted_1d_segment_mean(input, seg_id, n_graphs, 0)
+        return y
+    else:
+        if weight is None:
+            return F.mean(input, 0)
+        else:
+            y = F.sum(input, 0) / F.sum(weight, 0)
+            return y
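For the batched weighted case, the per-graph mean decomposes into a ratio of two segment sums, sum_g(w_i * x_i) / sum_g(w_i). A numpy sketch with two graphs:

import numpy as np

x = np.array([[1.], [2.], [3.]])
w = np.array([[2.], [0.], [1.]])  # weights already reshaped to (-1, 1)
seg = np.array([0, 0, 1])         # graph 0 owns rows 0-1, graph 1 owns row 2

num = np.zeros((2, 1)); np.add.at(num, seg, w * x)
den = np.zeros((2, 1)); np.add.at(den, seg, w)
mean = num / den  # [[1.], [3.]]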
+def mean_nodes(graph, input, weight=None):
+    """Average all the values of node field `input` in `graph`; if the
+    scalar node field `weight` is given, compute a weighted average instead.
+
+    Parameters
+    ----------
+    graph : DGLGraph or BatchedDGLGraph
+        The graph.
+    input : str
+        The input field.
+    weight : str, optional
+        The weight field. Default: None (no weighting).
+
+    Returns
+    -------
+    tensor
+        The averaged tensor.
+
+    Notes
+    -----
+    If graph is a BatchedDGLGraph, a stacked tensor is returned instead,
+    i.e. having an extra first dimension.
+    Each row of the stacked tensor contains the readout result of the
+    corresponding example in the batch. If an example has no nodes,
+    its row is a zero tensor of the same shape.
+    """
+    return _mean_on(graph, 'nodes', input, weight)
+def mean_edges(graph, input, weight=None):
+    """Average all the values of edge field `input` in `graph`; if the
+    scalar edge field `weight` is given, compute a weighted average instead.
+
+    Parameters
+    ----------
+    graph : DGLGraph or BatchedDGLGraph
+        The graph.
+    input : str
+        The input field.
+    weight : str, optional
+        The weight field. Default: None (no weighting).
+
+    Returns
+    -------
+    tensor
+        The averaged tensor.
+
+    Notes
+    -----
+    If graph is a BatchedDGLGraph, a stacked tensor is returned instead,
+    i.e. having an extra first dimension.
+    Each row of the stacked tensor contains the readout result of the
+    corresponding example in the batch. If an example has no edges,
+    its row is a zero tensor of the same shape.
+    """
+    return _mean_on(graph, 'edges', input, weight)
"""Built-in functions."""
from functools import update_wrapper
__all__ = ['create_bundled_function_class']
def create_bundled_function_class(name, cls):
    class Bundled(cls):
        def __init__(self, fn_list):
            if not isinstance(fn_list, (list, tuple)):
                fn_list = [fn_list]
            self.fn_list = fn_list

        def is_spmv_supported(self, *args, **kwargs):
            # a bundle is SPMV-compatible only if every member is
            return all(isinstance(fn, cls) and
                       fn.is_spmv_supported(*args, **kwargs)
                       for fn in self.fn_list)

        def __call__(self, *args, **kwargs):
            # call each bundled function and merge the result dicts
            ret = {}
            for fn in self.fn_list:
                result = fn(*args, **kwargs)
                ret.update(result)
            return ret

        def name(self):
            return "bundled"

    # Fake the names for introspection
    Bundled.__module__ = cls.__module__
    Bundled.__name__ = name
    Bundled.__qualname__ = name
    for method_name in ('__init__', '__call__', 'is_spmv_supported', 'name'):
        method = getattr(Bundled, method_name)
        method.__qualname__ = f'{Bundled.__qualname__}.{method_name}'
    for method_name in ('__call__', 'is_spmv_supported', 'name'):
        # update_wrapper mutates the method in place
        method = getattr(Bundled, method_name)
        update_wrapper(method,
                       cls.__dict__[method.__name__],
                       ('__module__', '__doc__', '__annotations__'))
    return Bundled
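As a usage sketch (assuming the builtin message module is importable as dgl.function.message):

from dgl.function.message import copy_src, BundledMessageFunction

# bundle two builtin message functions; calling the bundle evaluates each
# one and merges their output dicts into a single message dict
bundle = BundledMessageFunction([copy_src('h', 'm_h'), copy_src('x', 'm_x')])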
@@ -3,6 +3,7 @@ from __future__ import absolute_import
 import operator

 import dgl.backend as F
+from .base import create_bundled_function_class

 __all__ = ["src_mul_edge", "copy_src", "copy_edge"]
@@ -26,27 +27,8 @@ class MessageFunction(object):
         raise NotImplementedError

-class BundledMessageFunction(MessageFunction):
-    def __init__(self, fn_list):
-        if not isinstance(fn_list, (list, tuple)):
-            fn_list = [fn_list]
-        self.fn_list = fn_list
-
-    def is_spmv_supported(self, g):
-        for fn in self.fn_list:
-            if not isinstance(fn, MessageFunction) or not fn.is_spmv_supported(g):
-                return False
-        return True
-
-    def __call__(self, edges):
-        ret = dict()
-        for fn in self.fn_list:
-            msg = fn(edges)
-            ret.update(msg)
-        return ret
-
-    def name(self):
-        return "bundled"
+BundledMessageFunction = create_bundled_function_class(
+    'BundledMessageFunction', MessageFunction)

 def _is_spmv_supported_node_feat(g, field):
@@ -80,8 +62,12 @@ class SrcMulEdgeMessageFunction(MessageFunction):
                and _is_spmv_supported_edge_feat(g, self.edge_field)

     def __call__(self, edges):
+        src_data = edges.src[self.src_field]
+        edata = edges.data[self.edge_field]
+        src_dim = F.ndim(src_data)
+        eshape = F.shape(edata)[0]
+        # reshape the scalar edge data so it broadcasts over the trailing
+        # feature dimensions of the source data
         ret = self.mul_op(edges.src[self.src_field],
-                          edges.data[self.edge_field])
+                          F.reshape(edges.data[self.edge_field], (eshape,) + (1,) * (src_dim - 1)))
         return {self.out_field : ret}

     def name(self):
...
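The reshape matters because an (E,) scalar edge weight does not broadcast against an (E, D) source feature. A small illustration of the fix:

import torch as th

src = th.randn(4, 5)  # (E, D) source features
w = th.randn(4)       # (E,) scalar edge weights
# src * w raises a shape error; reshaping w to (E, 1) broadcasts row-wise:
out = src * w.reshape(4, 1)  # (4, 5)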
@@ -2,6 +2,7 @@
 from __future__ import absolute_import

 from .. import backend as F
+from .base import create_bundled_function_class

 __all__ = ["sum", "max"]
@@ -23,44 +24,29 @@ class ReduceFunction(object):
         """Return whether the SPMV optimization is supported."""
         raise NotImplementedError

-class BundledReduceFunction(ReduceFunction):
-    def __init__(self, fn_list):
-        if not isinstance(fn_list, (list, tuple)):
-            fn_list = [fn_list]
-        self.fn_list = fn_list
-
-    def is_spmv_supported(self):
-        for fn in self.fn_list:
-            if not isinstance(fn, ReduceFunction) or not fn.is_spmv_supported():
-                return False
-        return True
-
-    def __call__(self, nodes):
-        ret = dict()
-        for fn in self.fn_list:
-            rpr = fn(nodes)
-            ret.update(rpr)
-        return ret
-
-    def name(self):
-        return "bundled"
+BundledReduceFunction = create_bundled_function_class(
+    'BundledReduceFunction', ReduceFunction)
-class ReducerFunctionTemplate(ReduceFunction):
+class SimpleReduceFunction(ReduceFunction):
+    """Builtin reduce function that aggregates a single field into another
+    single field."""
     def __init__(self, name, op, msg_field, out_field):
-        self.name = name
+        self._name = name
         self.op = op
         self.msg_field = msg_field
         self.out_field = out_field

     def is_spmv_supported(self):
         # NOTE: only sum is supported right now.
-        return self.name == "sum"
+        return self._name == "sum"

     def __call__(self, nodes):
         return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)}

     def name(self):
-        return self.name
+        return self._name
 def sum(msg, out):
     """Builtin reduce function that aggregates messages by sum.
@@ -72,7 +58,7 @@ def sum(msg, out):
     out : str
         The output node feature name.
     """
-    return ReducerFunctionTemplate("sum", F.sum, msg, out)
+    return SimpleReduceFunction("sum", F.sum, msg, out)
 def max(msg, out):
     """Builtin reduce function that aggregates messages by max.
@@ -84,4 +70,4 @@ def max(msg, out):
     out : str
         The output node feature name.
     """
-    return ReducerFunctionTemplate("max", F.max, msg, out)
+    return SimpleReduceFunction("max", F.max, msg, out)
import torch as th
import dgl

def test_simple_readout():
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    n1 = th.randn(3, 5)
    n2 = th.randn(4, 5)
    e1 = th.randn(3, 5)
    s1 = n1.sum(0)    # node sums
    s2 = n2.sum(0)
    se1 = e1.sum(0)   # edge sums
    m1 = n1.mean(0)   # node means
    m2 = n2.mean(0)
    me1 = e1.mean(0)  # edge means
    w1 = th.randn(3)
    w2 = th.randn(4)
    ws1 = (n1 * w1[:, None]).sum(0)                       # weighted node sums
    ws2 = (n2 * w2[:, None]).sum(0)
    wm1 = (n1 * w1[:, None]).sum(0) / w1[:, None].sum(0)  # weighted node means
    wm2 = (n2 * w2[:, None]).sum(0) / w2[:, None].sum(0)
    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    # readouts on a single graph
    assert th.allclose(dgl.sum_nodes(g1, 'x'), s1)
    assert th.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
    assert th.allclose(dgl.sum_edges(g1, 'x'), se1)
    assert th.allclose(dgl.mean_nodes(g1, 'x'), m1)
    assert th.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
    assert th.allclose(dgl.mean_edges(g1, 'x'), me1)

    # readouts on a batch; g2 has no edges, so its edge readouts are zero
    g = dgl.batch([g1, g2])
    s = dgl.sum_nodes(g, 'x')
    m = dgl.mean_nodes(g, 'x')
    assert th.allclose(s, th.stack([s1, s2], 0))
    assert th.allclose(m, th.stack([m1, m2], 0))
    ws = dgl.sum_nodes(g, 'x', 'w')
    wm = dgl.mean_nodes(g, 'x', 'w')
    assert th.allclose(ws, th.stack([ws1, ws2], 0))
    assert th.allclose(wm, th.stack([wm1, wm2], 0))
    s = dgl.sum_edges(g, 'x')
    m = dgl.mean_edges(g, 'x')
    assert th.allclose(s, th.stack([se1, th.zeros(5)], 0))
    assert th.allclose(m, th.stack([me1, th.zeros(5)], 0))

if __name__ == '__main__':
    test_simple_readout()