Commit 40ca5de4 authored by Gan Quan, committed by Minjie Wang

[API] Readout interfaces (#124)

* fixing builtin src*edge shape mismatch

* bundled function refactor (?)

* fixing names

* readout prototype

* oops

* more fixes

* removing readout prototype

* sum_on() with SPMV, fixing batching with 0 edges

* readouts with segmented sum

* typo (??????)

* fixes NLTK dependency (#125)

* misc fixes including #126 (pushing again)

* sanity check for mxnet

* fixes NLTK dependency (#125) and #126

* reverting to sum_nodes/edges
parent 8ea359d1
@@ -103,8 +103,6 @@ venv.bak/
# mypy
.mypy_cache/
examples/pytorch/data/ind.pubmed.y
examples/pytorch/data/ind.pubmed.x
examples/pytorch/data/ind.pubmed.ty
@@ -136,3 +134,13 @@ examples/pytorch/generative_graph/*.p
# data directory
_download
# CTags & CScope
tags
cscope.*
# Vim
*.swp
*.swo
*.un~
*~
@@ -162,6 +162,21 @@ def dtype(input):
"""
pass
def ndim(input):
"""Return the number of dimensions of the tensor.
Parameters
----------
input : Tensor
The input tensor.
Returns
-------
int
The number of dimensions.
"""
pass
def context(input):
"""Return the context/device of the input tensor.
@@ -251,6 +266,23 @@ def sum(input, dim):
"""
pass
def mean(input, dim):
"""Reduce average the input tensor along the given dim.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The reduce dim.
Returns
-------
Tensor
A framework-specific tensor.
"""
pass
def max(input, dim):
"""Reduce max the input tensor along the given dim.
@@ -285,6 +317,23 @@ def cat(seq, dim):
"""
pass
def stack(seq, dim):
"""Stack the sequence of tensors along the given dimension.
Parameters
----------
seq : list of Tensor
The tensor sequence.
dim : int
The dim along which to stack.
Returns
-------
Tensor
A framework-specific tensor.
"""
pass
def split(input, sizes_or_sections, dim):
"""Split the input tensor into chunks.
@@ -484,6 +533,58 @@ def spmm(x, y):
"""
pass
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
"""Computes the sum along segments of a tensor.
Equivalent to tf.unsorted_segment_sum, but seg_id is required to be a
1D tensor.
Parameters
----------
input : Tensor
The input tensor
seg_id : 1D Tensor
The segment IDs whose values are between 0 and n_segs - 1. Should
have the same length as input.
n_segs : int
Number of distinct segments
dim : int
Dimension to sum on
Returns
-------
Tensor
The result
"""
pass
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
"""Computes the mean along segments of a tensor.
Equivalent to tf.unsorted_segment_mean, but seg_id is required to be a
1D tensor.
Note that segments that never appear in seg_id will have a result of 0.
Parameters
----------
input : Tensor
The input tensor
seg_id : 1D Tensor
The segment IDs whose values are between 0 and n_segs - 1. Should
have the same length as input.
n_segs : int
Number of distinct segments
dim : int
Dimension to average on
Returns
-------
Tensor
The result
"""
pass
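# For illustration, a toy sketch of the intended semantics of the two segment
# functions above (a hypothetical standalone example, shown with the PyTorch
# backend's tensors):
#
#   import torch as th
#   x = th.tensor([[1., 2.], [3., 4.], [5., 6.]])
#   seg_id = th.tensor([0, 1, 0])   # rows 0 and 2 belong to segment 0
#   unsorted_1d_segment_sum(x, seg_id, n_segs=2, dim=0)
#   # -> [[6., 8.], [3., 4.]]
#   unsorted_1d_segment_mean(x, seg_id, n_segs=2, dim=0)
#   # -> [[3., 4.], [3., 4.]]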
###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
@@ -49,6 +49,9 @@ def dtype(input):
# NOTE: the input cannot be a symbol
return input.dtype
def ndim(input):
return input.ndim
def context(input):
return input.context
@@ -64,12 +67,18 @@ def copy_to(input, ctx):
def sum(input, dim):
return nd.sum(input, axis=dim)
def mean(input, dim):
return nd.mean(input, axis=dim)
def max(input, dim):
return nd.max(input, axis=dim).asnumpy()[0]
def cat(seq, dim):
return nd.concat(*seq, dim=dim)
def stack(seq, dim):
return nd.stack(*seq, dim=dim)
def split(x, sizes_or_sections, dim):
if isinstance(sizes_or_sections, list):
# TODO: fallback to numpy is unfortunate
@@ -114,6 +123,35 @@ def ones(shape, dtype):
def spmm(x, y):
return nd.dot(x, y)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
    # TODO: support other dimensions
    assert dim == 0, 'MXNet only supports segment sum on first dimension'
    # Simulate segment sum with SPMV: build a sparse (n_segs, n_inputs)
    # selection matrix w with w[seg_id[i], i] = 1, so that w @ input sums
    # the rows of each segment.
    ctx = input.context
    n_inputs = input.shape[0]
    input_shape_suffix = input.shape[1:]
    input = input.reshape(n_inputs, -1)
    n_range = nd.arange(n_inputs, dtype='int64').as_in_context(ctx)
    w_nnz = nd.ones(n_inputs).as_in_context(ctx)
    w = nd.sparse.csr_matrix((w_nnz, (seg_id, n_range)), (n_segs, n_inputs))
    w = w.as_in_context(ctx)
    y = nd.dot(w, input)
    y = nd.reshape(y, (n_segs,) + input_shape_suffix)
    return y
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
    # TODO: support other dimensions
    assert dim == 0, 'MXNet only supports segment mean on first dimension'
    # Per-segment counts, clipped to 1 so that empty segments yield 0
    # instead of NaN when dividing.
    n_ones = nd.ones_like(seg_id).astype(input.dtype)
    w = unsorted_1d_segment_sum(n_ones, seg_id, n_segs, 0)
    w = nd.clip(w, a_min=1, a_max=np.inf)
    y = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
    y /= w.reshape((-1,) + (1,) * (y.ndim - 1))
    return y
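# Sketch of the SPMV trick above with toy numbers (hypothetical example):
# seg_id = [0, 1, 0] over 3 inputs and 2 segments builds the one-hot CSR
# selection matrix
#     W = [[1., 0., 1.],
#          [0., 1., 0.]]
# so that nd.dot(W, input) adds rows 0 and 2 into segment 0 and row 1 into
# segment 1 with a single sparse matrix product.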
def unique(input):
# TODO: fallback to numpy is unfortunate
tmp = input.asnumpy()
@@ -37,6 +37,9 @@ def shape(input):
def dtype(input):
return input.dtype
def ndim(input):
return input.dim()
def context(input):
return input.device
@@ -58,6 +61,9 @@ def copy_to(input, ctx):
def sum(input, dim):
return th.sum(input, dim=dim)
def mean(input, dim):
return th.mean(input, dim=dim)
def max(input, dim):
# NOTE: the second argmax array is not returned
return th.max(input, dim=dim)[0]
@@ -65,6 +71,9 @@ def max(input, dim):
def cat(seq, dim):
return th.cat(seq, dim=dim)
def stack(seq, dim):
return th.stack(seq, dim=dim)
def split(input, sizes_or_sections, dim):
return th.split(input, sizes_or_sections, dim)
@@ -98,6 +107,19 @@ def ones(shape, dtype):
def spmm(x, y):
return th.spmm(x, y)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
    # Scatter-add each row of input into the output row given by its segment ID.
    y = th.zeros(n_segs, *input.shape[1:]).to(input)
    seg_id = seg_id.view((-1,) + (1,) * (input.dim() - 1)).expand_as(input)
    y = y.scatter_add_(dim, seg_id, input)
    return y
def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
    # Per-segment counts, clamped to 1 so that empty segments yield 0
    # instead of NaN when dividing.
    w = unsorted_1d_segment_sum(th.ones_like(seg_id), seg_id, n_segs, 0).to(input)
    w = w.clamp(min=1)
    y = unsorted_1d_segment_sum(input, seg_id, n_segs, dim)
    y /= w.view((-1,) + (1,) * (y.dim() - 1))
    return y
def unique(input):
return th.unique(input)
@@ -10,7 +10,8 @@ from . import graph_index as gi
from . import backend as F
from . import utils
__all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split',
           'sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges']
class BatchedDGLGraph(DGLGraph):
"""Class for batched DGL graphs.
@@ -31,11 +32,13 @@ class BatchedDGLGraph(DGLGraph):
batched_index = gi.disjoint_union([g._graph for g in graph_list])
# create batched node and edge frames
# NOTE: following code will materialize the columns of the input graphs.
cols = {key: F.cat([gr._node_frame[key] for gr in graph_list
                    if gr.number_of_nodes() > 0], dim=0)
        for key in node_attrs}
batched_node_frame = FrameRef(Frame(cols))
cols = {key: F.cat([gr._edge_frame[key] for gr in graph_list
                    if gr.number_of_edges() > 0], dim=0)
        for key in edge_attrs}
batched_edge_frame = FrameRef(Frame(cols))
@@ -98,25 +101,6 @@
# TODO
pass
def split(graph_batch, num_or_size_splits):
"""Split the batch."""
# TODO(minjie): could follow torch.split syntax
@@ -136,7 +120,7 @@ def unbatch(graph):
smaller partitions. This is usually wasteful.
For simpler tasks such as per-example node/edge state aggregation,
try to use readout functions.
"""
assert isinstance(graph, BatchedDGLGraph)
bsize = graph.batch_size
@@ -191,3 +175,177 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
elif isinstance(edge_attrs, str):
edge_attrs = [edge_attrs]
return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
_readout_on_attrs = {
'nodes': ('ndata', 'batch_num_nodes', 'number_of_nodes'),
'edges': ('edata', 'batch_num_edges', 'number_of_edges'),
}
def _sum_on(graph, on, input, weight):
    data_attr, batch_num_objs_attr, num_objs_attr = _readout_on_attrs[on]
    data = getattr(graph, data_attr)
    input = data[input]
    if weight is not None:
        # Reshape the scalar weight to (N, 1, ..., 1) so it broadcasts
        # against the input field.
        weight = data[weight]
        weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(input) - 1))
        input = weight * input
    if isinstance(graph, BatchedDGLGraph):
        n_graphs = graph.batch_size
        batch_num_objs = getattr(graph, batch_num_objs_attr)
        # e.g. batch_num_objs = [3, 4] yields seg_id = [0, 0, 0, 1, 1, 1, 1],
        # mapping each node/edge to the graph it belongs to
        seg_id = F.zerocopy_from_numpy(
            np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
        seg_id = F.copy_to(seg_id, F.context(input))
        return F.unsorted_1d_segment_sum(input, seg_id, n_graphs, 0)
    else:
        return F.sum(input, 0)
def sum_nodes(graph, input, weight=None):
"""Sums all the values of node field `input` in `graph`, optionally
multiplies the field by a scalar node field `weight`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph
input : str
The input field
weight : str, optional
The weight field; default is None (no weighting)
Returns
-------
tensor
The summed tensor.
Notes
-----
If graph is a BatchedDGLGraph, a stacked tensor is returned instead,
i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of the
corresponding example in the batch. If an example has no nodes,
a zero tensor with the same shape is returned for that row.
"""
return _sum_on(graph, 'nodes', input, weight)
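# Example (mirroring the unit test at the end of this commit): on a single
# graph, sum_nodes reduces over the node dimension, optionally weighted:
#
#   import torch as th
#   import dgl
#   g = dgl.DGLGraph()
#   g.add_nodes(3)
#   g.ndata['x'] = th.randn(3, 5)
#   g.ndata['w'] = th.randn(3)
#   dgl.sum_nodes(g, 'x')       # == g.ndata['x'].sum(0)
#   dgl.sum_nodes(g, 'x', 'w')  # == (g.ndata['w'][:, None] * g.ndata['x']).sum(0)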
def sum_edges(graph, input, weight=None):
"""Sums all the values of edge field `input` in `graph`, optionally
multiplies the field by a scalar edge field `weight`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph
input : str
The input field
weight : str, optional
The weight field; default is None (no weighting)
Returns
-------
tensor
The summed tensor.
Notes
-----
If graph is a BatchedDGLGraph, a stacked tensor is returned instead,
i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of the
corresponding example in the batch. If an example has no edges,
a zero tensor with the same shape is returned for that row.
"""
return _sum_on(graph, 'edges', input, weight)
def _mean_on(graph, on, input, weight):
    data_attr, batch_num_objs_attr, num_objs_attr = _readout_on_attrs[on]
    data = getattr(graph, data_attr)
    input = data[input]
    if weight is not None:
        # Reshape the scalar weight to (N, 1, ..., 1) so it broadcasts
        # against the input field.
        weight = data[weight]
        weight = F.reshape(weight, (-1,) + (1,) * (F.ndim(input) - 1))
        input = weight * input
    if isinstance(graph, BatchedDGLGraph):
        n_graphs = graph.batch_size
        batch_num_objs = getattr(graph, batch_num_objs_attr)
        seg_id = F.zerocopy_from_numpy(
            np.arange(n_graphs, dtype='int64').repeat(batch_num_objs))
        seg_id = F.copy_to(seg_id, F.context(input))
        if weight is not None:
            # Weighted mean: per-graph sum(w * x) divided by per-graph sum(w).
            w = F.unsorted_1d_segment_sum(weight, seg_id, n_graphs, 0)
            y = F.unsorted_1d_segment_sum(input, seg_id, n_graphs, 0)
            return y / w
        else:
            return F.unsorted_1d_segment_mean(input, seg_id, n_graphs, 0)
    else:
        if weight is None:
            return F.mean(input, 0)
        else:
            return F.sum(input, 0) / F.sum(weight, 0)
def mean_nodes(graph, input, weight=None):
"""Averages all the values of node field `input` in `graph`, optionally
multiplies the field by a scalar node field `weight`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph
input : str
The input field
weight : str, optional
The weight field; default is None (no weighting)
Returns
-------
tensor
The averaged tensor.
Notes
-----
If graph is a BatchedDGLGraph, a stacked tensor is returned instead,
i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of the
corresponding example in the batch. If an example has no nodes,
a zero tensor with the same shape is returned for that row.
"""
return _mean_on(graph, 'nodes', input, weight)
def mean_edges(graph, input, weight=None):
"""Averages all the values of edge field `input` in `graph`, optionally
multiplies the field by a scalar edge field `weight`.
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph
input : str
The input field
weight : str, optional
The weight field; default is None (no weighting)
Returns
-------
tensor
The averaged tensor.
Notes
-----
If graph is a BatchedDGLGraph, a stacked tensor is returned instead,
i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of the
corresponding example in the batch. If an example has no edges,
a zero tensor with the same shape is returned for that row.
"""
return _mean_on(graph, 'edges', input, weight)
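# Example with a batched graph (cf. the unit test at the end of this commit):
# each row of the result is one graph's readout, and graphs with no edges
# contribute a zero row:
#
#   bg = dgl.batch([g1, g2])    # suppose g2 has no edges
#   dgl.mean_edges(bg, 'x')     # shape (batch_size, ...); row for g2 is zeros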
"""Built-in functions."""
from functools import update_wrapper
__all__ = ['create_bundled_function_class']
def create_bundled_function_class(name, cls):
class Bundled(cls):
def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list]
self.fn_list = fn_list
def is_spmv_supported(self, *args, **kwargs):
return all(isinstance(fn, cls) and
fn.is_spmv_supported(*args, **kwargs)
for fn in self.fn_list)
def __call__(self, *args, **kwargs):
ret = {}
for fn in self.fn_list:
result = fn(*args, **kwargs)
ret.update(result)
return ret
def name(self):
return "bundled"
# Fake the names for introspection
Bundled.__module__ = cls.__module__
Bundled.__name__ = name
Bundled.__qualname__ = name
for method_name in ('__init__', '__call__', 'is_spmv_supported', 'name'):
method = getattr(Bundled, method_name)
method.__qualname__ = f'{Bundled.__qualname__}.{method_name}'
for method_name in ('__call__', 'is_spmv_supported', 'name'):
method = getattr(Bundled, method_name)
method = update_wrapper(method,
cls.__dict__[method.__name__],
('__module__', '__doc__', '__annotations__'))
return Bundled
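# Usage sketch: message.py and reducer.py below derive their bundled classes
# from this factory, e.g.
#
#   BundledMessageFunction = create_bundled_function_class(
#       'BundledMessageFunction', MessageFunction)
#   bundled = BundledMessageFunction([copy_src('h', 'm'), copy_edge('w', 'ew')])
#   # bundled(edges) merges the message dicts returned by both builtins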
@@ -3,6 +3,7 @@ from __future__ import absolute_import
import operator
import dgl.backend as F
from .base import create_bundled_function_class
__all__ = ["src_mul_edge", "copy_src", "copy_edge"]
@@ -26,27 +27,8 @@ class MessageFunction(object):
raise NotImplementedError
BundledMessageFunction = create_bundled_function_class(
    'BundledMessageFunction', MessageFunction)
def _is_spmv_supported_node_feat(g, field):
@@ -80,8 +62,12 @@ class SrcMulEdgeMessageFunction(MessageFunction):
and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, edges):
    src_data = edges.src[self.src_field]
    edata = edges.data[self.edge_field]
    # Reshape the edge data to (E, 1, ..., 1) so the elementwise product
    # broadcasts against the source features.
    src_dim = F.ndim(src_data)
    eshape = F.shape(edata)[0]
    ret = self.mul_op(src_data,
                      F.reshape(edata, (eshape,) + (1,) * (src_dim - 1)))
    return {self.out_field : ret}
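# E.g. (hypothetical shapes): src features of shape (E, D) with scalar
# edge weights of shape (E,) yield mul_op((E, D), (E, 1)), which
# broadcasts to (E, D) instead of raising a shape mismatch.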
def name(self):
@@ -2,6 +2,7 @@
from __future__ import absolute_import
from .. import backend as F
from .base import create_bundled_function_class
__all__ = ["sum", "max"]
@@ -23,44 +24,29 @@ class ReduceFunction(object):
"""Return whether the SPMV optimization is supported."""
raise NotImplementedError
BundledReduceFunction = create_bundled_function_class(
    'BundledReduceFunction', ReduceFunction)
class SimpleReduceFunction(ReduceFunction):
"""Builtin reduce function that aggregates a single field into another
single field."""
def __init__(self, name, op, msg_field, out_field):
self._name = name
self.op = op
self.msg_field = msg_field
self.out_field = out_field
def is_spmv_supported(self):
# NOTE: only sum is supported right now.
return self._name == "sum"
def __call__(self, nodes):
return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)}
def name(self):
return self._name
def sum(msg, out):
"""Builtin reduce function that aggregates messages by sum.
@@ -72,7 +58,7 @@ def sum(msg, out):
out : str
The output node feature name.
"""
return ReducerFunctionTemplate("sum", F.sum, msg, out)
return SimpleReduceFunction("sum", F.sum, msg, out)
def max(msg, out):
"""Builtin reduce function that aggregates messages by max.
@@ -84,4 +70,4 @@ def max(msg, out):
out : str
The output node feature name.
"""
return ReducerFunctionTemplate("max", F.max, msg, out)
return SimpleReduceFunction("max", F.max, msg, out)
import torch as th
import dgl
def test_simple_readout():
g1 = dgl.DGLGraph()
g1.add_nodes(3)
g2 = dgl.DGLGraph()
g2.add_nodes(4) # no edges
g1.add_edges([0, 1, 2], [2, 0, 1])
n1 = th.randn(3, 5)
n2 = th.randn(4, 5)
e1 = th.randn(3, 5)
s1 = n1.sum(0) # node sums
s2 = n2.sum(0)
se1 = e1.sum(0) # edge sums
m1 = n1.mean(0) # node means
m2 = n2.mean(0)
me1 = e1.mean(0) # edge means
w1 = th.randn(3)
w2 = th.randn(4)
ws1 = (n1 * w1[:, None]).sum(0) # weighted node sums
ws2 = (n2 * w2[:, None]).sum(0)
wm1 = (n1 * w1[:, None]).sum(0) / w1[:, None].sum(0) # weighted node means
wm2 = (n2 * w2[:, None]).sum(0) / w2[:, None].sum(0)
g1.ndata['x'] = n1
g2.ndata['x'] = n2
g1.ndata['w'] = w1
g2.ndata['w'] = w2
g1.edata['x'] = e1
assert th.allclose(dgl.sum_nodes(g1, 'x'), s1)
assert th.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
assert th.allclose(dgl.sum_edges(g1, 'x'), se1)
assert th.allclose(dgl.mean_nodes(g1, 'x'), m1)
assert th.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
assert th.allclose(dgl.mean_edges(g1, 'x'), me1)
g = dgl.batch([g1, g2])
s = dgl.sum_nodes(g, 'x')
m = dgl.mean_nodes(g, 'x')
assert th.allclose(s, th.stack([s1, s2], 0))
assert th.allclose(m, th.stack([m1, m2], 0))
ws = dgl.sum_nodes(g, 'x', 'w')
wm = dgl.mean_nodes(g, 'x', 'w')
assert th.allclose(ws, th.stack([ws1, ws2], 0))
assert th.allclose(wm, th.stack([wm1, wm2], 0))
s = dgl.sum_edges(g, 'x')
m = dgl.mean_edges(g, 'x')
assert th.allclose(s, th.stack([se1, th.zeros(5)], 0))
assert th.allclose(m, th.stack([me1, th.zeros(5)], 0))
if __name__ == '__main__':
test_simple_readout()