Commit 2c234118 authored by Zihao Ye, committed by Minjie Wang

[Feature] Add builtin mean reducer (#787)

* upd

* upd

* upd

* upd

* upd

* passed test

* add note

* upd

* trigger

* slight change

* upd

* upd

* trigger

* fix

* simplify

* upd

* upd

* fudge

* upd

* trigger

* test partial

* upd

* trigger
parent 73b2668f
......@@ -107,6 +107,8 @@ Here is a cheatsheet of all the DGL builtins.
|                         | ``sum``                                            |                       |
|                         +----------------------------------------------------+-----------------------+
|                         | ``prod``                                           |                       |
|                         +----------------------------------------------------+-----------------------+
|                         | ``mean``                                           |                       |
+-------------------------+----------------------------------------------------+-----------------------+
Next Step
......
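For reference, the new builtin can be used like any other reducer; a minimal usage sketch (the toy graph and feature names below are illustrative, not part of this commit):

import dgl
import dgl.function as fn
import torch as th

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [2, 2])                    # two edges pointing at node 2
g.ndata['h'] = th.tensor([[1.], [3.], [0.]])

# average incoming messages instead of summing them
g.update_all(fn.copy_src(src='h', out='m'),
             fn.mean(msg='m', out='h_mean'))
# node 2 receives (1 + 3) / 2 = 2; degrees are clipped to 1 internally,
# so nodes without in-edges never divide by zero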
......@@ -9,6 +9,7 @@ import numbers
import builtins
from ... import ndarray as dglnd
from ... import kernel as K
from ...function.base import TargetCode
MX_VERSION = LooseVersion(mx.__version__)
# After MXNet 1.5, empty tensors aren't supported by default.
......@@ -368,20 +369,48 @@ class BinaryReduce(mx.autograd.Function):
ctx=lhs_data.context, dtype=lhs_data.dtype)
out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
K.binary_op_reduce(
self.reducer, self.binary_op, self.graph, self.lhs, self.rhs,
self.reducer if self.reducer != 'mean' else 'sum',
self.binary_op, self.graph, self.lhs, self.rhs,
lhs_data_nd, rhs_data_nd, out_data_nd, self.lhs_map[0],
self.rhs_map[0], self.out_map[0])
# normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
if self.reducer == 'mean':
degs = nd.empty((out_data.shape[0],),
ctx=out_data.context, dtype=out_data.dtype)
degs_nd = zerocopy_to_dgl_ndarray(degs)
if self.lhs != TargetCode.DST:
target = self.lhs
n = lhs_data.shape[0]
in_map = self.lhs_map[0]
else:
target = self.rhs
n = rhs_data.shape[0]
in_map = self.rhs_map[0]
in_ones = nd.ones((n,), ctx=lhs_data.context, dtype=lhs_data.dtype)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
K.copy_reduce(
'sum', self.graph, target, in_ones_nd, degs_nd,
in_map, self.out_map[0])
# reshape degs to broadcast over the feature dims; clip to 1 so zero-degree nodes do not divide by zero
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
out_data = out_data / degs
else:
degs = None
self.save_for_backward(lhs_data_nd, rhs_data_nd, out_data_nd,
feat_shape)
feat_shape, degs)
return out_data
def backward(self, grad_out):
lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape = self.saved_tensors
lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape, degs = self.saved_tensors
if self.reducer == 'mean':
grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
grad_lhs = nd.empty((lhs_data_nd.shape[0],) + feat_shape,
ctx=grad_out.context, dtype=grad_out.dtype)
K.backward_lhs_binary_op_reduce(
self.reducer, self.binary_op, self.graph, self.lhs, self.rhs,
self.reducer if self.reducer != 'mean' else 'sum',
self.binary_op, self.graph, self.lhs, self.rhs,
lhs_data_nd, rhs_data_nd, out_data_nd, grad_out_nd,
zerocopy_to_dgl_ndarray_for_write(grad_lhs), self.lhs_map[1],
self.rhs_map[1], self.out_map[1])
......@@ -389,7 +418,8 @@ class BinaryReduce(mx.autograd.Function):
grad_rhs = nd.empty((rhs_data_nd.shape[0],) + feat_shape,
ctx=grad_out.context, dtype=grad_out.dtype)
K.backward_rhs_binary_op_reduce(
self.reducer, self.binary_op, self.graph, self.lhs, self.rhs,
self.reducer if self.reducer != 'mean' else 'sum',
self.binary_op, self.graph, self.lhs, self.rhs,
lhs_data_nd, rhs_data_nd, out_data_nd, grad_out_nd,
zerocopy_to_dgl_ndarray_for_write(grad_rhs), self.lhs_map[1],
self.rhs_map[1], self.out_map[1])
......@@ -423,18 +453,39 @@ class CopyReduce(mx.autograd.Function):
in_data_nd = zerocopy_to_dgl_ndarray(in_data)
out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
K.copy_reduce(
self.reducer, self.graph, self.target, in_data_nd, out_data_nd,
self.reducer if self.reducer != 'mean' else 'sum',
self.graph, self.target, in_data_nd, out_data_nd,
self.in_map[0], self.out_map[0])
self.save_for_backward(in_data_nd, out_data_nd)
# normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
if self.reducer == 'mean':
in_ones = nd.ones((in_data.shape[0],),
ctx=in_data.context, dtype=in_data.dtype)
degs = nd.empty((out_data.shape[0],),
ctx=out_data.context, dtype=out_data.dtype)
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
degs_nd = zerocopy_to_dgl_ndarray(degs)
K.copy_reduce(
'sum', self.graph, self.target, in_ones_nd, degs_nd,
self.in_map[0], self.out_map[0])
# reshape degs to broadcast over the feature dims; clip to 1 so zero-degree nodes do not divide by zero
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
out_data = out_data / degs
else:
degs = None
self.save_for_backward(in_data_nd, out_data_nd, degs)
return out_data
def backward(self, grad_out):
in_data_nd, out_data_nd = self.saved_tensors
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
in_data_nd, out_data_nd, degs = self.saved_tensors
grad_in = nd.empty(in_data_nd.shape, ctx=grad_out.context,
dtype=grad_out.dtype)
if self.reducer == 'mean':
grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
K.backward_copy_reduce(
self.reducer, self.graph, self.target, in_data_nd, out_data_nd,
self.reducer if self.reducer != 'mean' else 'sum',
self.graph, self.target, in_data_nd, out_data_nd,
grad_out_nd, zerocopy_to_dgl_ndarray_for_write(grad_in),
self.in_map[1], self.out_map[1])
# clear saved tensors explicitly
......
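Both backends implement mean the same way: run the existing sum kernel, count how many messages reached each destination with a second copy_reduce over a tensor of ones, clip the counts to 1, and divide. A NumPy sketch of that normalization step (toy edge list, for illustration only):

import numpy as np

dst = np.array([2, 2, 0, 0])                   # message destination per edge
msg = np.array([[1.0], [3.0], [5.0], [7.0]])   # one message per edge

num_nodes = 3
summed = np.zeros((num_nodes, 1))
np.add.at(summed, dst, msg)                    # what the 'sum' kernel produces

degs = np.zeros((num_nodes, 1))
np.add.at(degs, dst, 1.0)                      # copy_reduce('sum') over ones = in-degree
degs = np.clip(degs, 1.0, None)                # never divide by zero on isolated nodes

mean = summed / degs                           # the value the builtin 'mean' reducer returns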
......@@ -8,6 +8,7 @@ from torch.utils import dlpack
from ... import ndarray as nd
from ... import kernel as K
from ...function.base import TargetCode
TH_VERSION = LooseVersion(th.__version__)
......@@ -289,34 +290,61 @@ class BinaryReduce(th.autograd.Function):
out_data = lhs_data.new_empty((out_size,) + feat_shape)
out_data_nd = zerocopy_to_dgl_ndarray(out_data)
K.binary_op_reduce(
reducer, binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
reducer if reducer != 'mean' else 'sum',
binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
out_data_nd, lhs_map[0], rhs_map[0], out_map[0])
# normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
if reducer == 'mean':
degs = lhs_data.new_empty((out_data.shape[0],))
degs_nd = zerocopy_to_dgl_ndarray(degs)
if lhs != TargetCode.DST: # src or edge
target = lhs
n = lhs_data.shape[0]
in_map = lhs_map[0]
else: # rhs != TargetCode.DST
target = rhs
n = rhs_data.shape[0]
in_map = rhs_map[0]
in_ones = lhs_data.new_ones((n,))
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
K.copy_reduce(
'sum', graph, target, in_ones_nd, degs_nd, in_map, out_map[0])
# reshape degs to broadcast over the feature dims; clamp to 1 so zero-degree nodes do not divide by zero
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
out_data = out_data / degs
else:
degs = None
# save_for_backward can only save variables
ctx.backward_cache = (reducer, binary_op, graph, lhs, rhs, lhs_map,
rhs_map, out_map, lhs_data_nd, rhs_data_nd,
out_data_nd, feat_shape)
out_data_nd, feat_shape, degs)
return out_data
@staticmethod
def backward(ctx, grad_out):
reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape \
lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape, degs \
= ctx.backward_cache
ctx.backward_cache = None
grad_lhs = None
grad_rhs = None
if reducer == 'mean':
grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
if ctx.needs_input_grad[5]:
grad_lhs = grad_out.new_empty((lhs_data_nd.shape[0],) + feat_shape)
K.backward_lhs_binary_op_reduce(
reducer, binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
reducer if reducer != 'mean' else 'sum',
binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
out_data_nd, grad_out_nd, zerocopy_to_dgl_ndarray(grad_lhs),
lhs_map[1], rhs_map[1], out_map[1])
grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape)
if ctx.needs_input_grad[6]:
grad_rhs = grad_out.new_empty((rhs_data_nd.shape[0],) + feat_shape)
K.backward_rhs_binary_op_reduce(
reducer, binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
reducer if reducer != 'mean' else 'sum',
binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
out_data_nd, grad_out_nd, zerocopy_to_dgl_ndarray(grad_rhs),
lhs_map[1], rhs_map[1], out_map[1])
grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)
......@@ -333,24 +361,41 @@ class CopyReduce(th.autograd.Function):
in_data_nd = zerocopy_to_dgl_ndarray(in_data)
out_data_nd = zerocopy_to_dgl_ndarray(out_data)
K.copy_reduce(
reducer, graph, target, in_data_nd, out_data_nd, in_map[0],
out_map[0])
reducer if reducer != 'mean' else 'sum',
graph, target, in_data_nd, out_data_nd, in_map[0], out_map[0])
# normalize if mean reducer
# NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
if reducer == 'mean':
in_ones = in_data.new_ones((in_data.shape[0],))
degs = in_data.new_empty((out_data.shape[0],))
in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
degs_nd = zerocopy_to_dgl_ndarray(degs)
K.copy_reduce(
'sum', graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0])
# reshape degs to broadcast over the feature dims; clamp to 1 so zero-degree nodes do not divide by zero
degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
out_data = out_data / degs
else:
degs = None
# save_for_backward can only save variables
ctx.backward_cache = (reducer, graph, target, in_map, out_map,
in_data_nd, out_data_nd)
in_data_nd, out_data_nd, degs)
return out_data
@staticmethod
def backward(ctx, grad_out):
reducer, graph, target, in_map, out_map, in_data_nd, out_data_nd \
reducer, graph, target, in_map, out_map, in_data_nd, out_data_nd, degs \
= ctx.backward_cache
ctx.backward_cache = None
grad_in = None
if reducer == 'mean':
grad_out = grad_out / degs
grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
if ctx.needs_input_grad[3]:
grad_in = grad_out.new_empty(in_data_nd.shape)
K.backward_copy_reduce(
reducer, graph, target, in_data_nd, out_data_nd, grad_out_nd,
reducer if reducer != 'mean' else 'sum',
graph, target, in_data_nd, out_data_nd, grad_out_nd,
zerocopy_to_dgl_ndarray(grad_in), in_map[1], out_map[1])
return None, None, None, grad_in, None, None, None
......
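For the backward pass, both backends divide grad_out by the saved degree tensor and then reuse the sum backward kernels. That is just the chain rule, since each incoming message contributes 1/deg to the mean; a small PyTorch check of the equivalence (toy sizes, illustrative only):

import torch as th

deg, feat = 3, 4
msgs = th.randn(deg, feat, requires_grad=True)

# gradient through a true mean reduction
th.mean(msgs, dim=0).sum().backward()

# gradient through "sum, then divide by degree" -- the route this commit takes
msgs2 = msgs.detach().clone().requires_grad_(True)
(th.sum(msgs2, dim=0) / deg).sum().backward()

print(th.allclose(msgs.grad, msgs2.grad))      # True: every message receives grad_out / deg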
......@@ -51,7 +51,7 @@ class SimpleReduceFunction(ReduceFunction):
###############################################################################
# Generate all following reducer functions:
# sum, max, min, prod
# sum, max, min, mean, prod
def _gen_reduce_builtin(reducer):
docstring = """Builtin reduce function that aggregates messages by {0}.
......@@ -87,7 +87,7 @@ __all__ = []
def _register_builtin_reduce_func():
"""Register builtin reduce functions"""
for reduce_op in ["max", "min", "sum", "prod"]:
for reduce_op in ["max", "min", "sum", "mean", "prod"]:
builtin = _gen_reduce_builtin(reduce_op)
setattr(sys.modules[__name__], reduce_op, builtin)
__all__.append(reduce_op)
......
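The reducer builtins are generated by _gen_reduce_builtin and attached to the module at import time, so adding "mean" to the registration list is all it takes to expose fn.mean. A toy sketch of that register-by-setattr pattern (the throwaway module and functions below are illustrative, not DGL code):

import types

mod = types.ModuleType("toy_reducers")           # stand-in for the reducer module

def _gen_reduce(name):
    def reducer(msg, out):
        return ("reduce", name, msg, out)        # real builtins return ReduceFunction objects
    reducer.__name__ = name
    return reducer

for reduce_op in ["max", "min", "sum", "mean", "prod"]:
    setattr(mod, reduce_op, _gen_reduce(reduce_op))

print(mod.mean('m', 'h'))                        # ('reduce', 'mean', 'm', 'h')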
......@@ -14,6 +14,8 @@ def udf_copy_src(edges):
def udf_copy_edge(edges):
return {'m': edges.data['e']}
def udf_mean(nodes):
return {'r2': nodes.mailbox['m'].mean(1)}
def udf_sum(nodes):
return {'r2': nodes.mailbox['m'].sum(1)}
......@@ -26,8 +28,8 @@ def udf_max(nodes):
D1 = 5
D2 = 3
D3 = 4
builtin = {'sum': fn.sum, 'max': fn.max}
udf_reduce = {'sum': udf_sum, 'max': udf_max}
builtin = {'sum': fn.sum, 'max': fn.max, 'mean': fn.mean}
udf_reduce = {'sum': udf_sum, 'max': udf_max, 'mean': udf_mean}
fill_value = {'sum': 0, 'max': float("-inf")}
......@@ -57,15 +59,21 @@ def generate_feature(g, broadcast='none'):
def test_copy_src_reduce():
def _test(red):
def _test(red, partial):
g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
hu, hv, he = generate_feature(g, 'none')
if partial:
nid = F.tensor(list(range(0, 100, 2)))
g.ndata['u'] = F.attach_grad(F.clone(hu))
g.ndata['v'] = F.attach_grad(F.clone(hv))
g.edata['e'] = F.attach_grad(F.clone(he))
with F.record_grad():
if partial:
g.pull(nid, fn.copy_src(src='u', out='m'),
builtin[red](msg='m', out='r1'))
else:
g.update_all(fn.copy_src(src='u', out='m'),
builtin[red](msg='m', out='r1'))
r1 = g.ndata['r1']
......@@ -78,6 +86,9 @@ def test_copy_src_reduce():
g.edata['e'] = F.attach_grad(F.clone(he))
with F.record_grad():
if partial:
g.pull(nid, udf_copy_src, udf_reduce[red])
else:
g.update_all(udf_copy_src, udf_reduce[red])
r2 = g.ndata['r2']
F.backward(r2.sum())
......@@ -86,19 +97,32 @@ def test_copy_src_reduce():
assert F.allclose(r1, r2)
assert(F.allclose(n_grad1, n_grad2))
_test('sum')
_test('max')
_test('sum', False)
_test('max', False)
_test('mean', False)
_test('sum', True)
_test('max', True)
_test('mean', True)
def test_copy_edge_reduce():
def _test(red):
def _test(red, partial):
g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
hu, hv, he = generate_feature(g, 'none')
if partial:
nid = F.tensor(list(range(0, 100, 2)))
g.ndata['u'] = F.attach_grad(F.clone(hu))
g.ndata['v'] = F.attach_grad(F.clone(hv))
g.edata['e'] = F.attach_grad(F.clone(he))
with F.record_grad():
if partial:
g.pull(nid, fn.copy_edge(edge='e', out='m'),
builtin[red](msg='m', out='r1'))
else:
g.update_all(fn.copy_edge(edge='e', out='m'),
builtin[red](msg='m', out='r1'))
r1 = g.ndata['r1']
......@@ -111,6 +135,9 @@ def test_copy_edge_reduce():
g.edata['e'] = F.attach_grad(F.clone(he))
with F.record_grad():
if partial:
g.pull(nid, udf_copy_edge, udf_reduce[red])
else:
g.update_all(udf_copy_edge, udf_reduce[red])
r2 = g.ndata['r2']
F.backward(r2.sum())
......@@ -119,12 +146,16 @@ def test_copy_edge_reduce():
assert F.allclose(r1, r2)
assert(F.allclose(e_grad1, e_grad2))
_test('sum')
_test('max')
_test('sum', False)
_test('max', False)
_test('mean', False)
_test('sum', True)
_test('max', True)
_test('mean', True)
def test_all_binary_builtins():
def _test(g, lhs, rhs, binary_op, reducer, broadcast='none'):
def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast='none'):
hu, hv, he = generate_feature(g, broadcast)
g.ndata['u'] = F.attach_grad(F.clone(hu))
g.ndata['v'] = F.attach_grad(F.clone(hv))
......@@ -143,8 +174,11 @@ def test_all_binary_builtins():
return g.edata["e"]
with F.record_grad():
if partial:
g.pull(nid, builtin_msg(lhs, rhs, 'm'), builtin_red('m', 'r1'))
else:
g.update_all(builtin_msg(lhs, rhs, 'm'), builtin_red('m', 'r1'))
r1 = g.ndata['r1']
r1 = g.ndata.pop('r1')
F.backward(r1.sum())
lhs_grad_1 = F.grad(target_feature_switch(g, lhs))
rhs_grad_1 = F.grad(target_feature_switch(g, rhs))
......@@ -175,8 +209,11 @@ def test_all_binary_builtins():
return {"r2": op(nodes.mailbox['m'], 1)}
with F.record_grad():
if partial:
g.pull(nid, mfunc, rfunc)
else:
g.update_all(mfunc, rfunc)
r2 = g.ndata['r2']
r2 = g.ndata.pop('r2')
F.backward(r2.sum(), F.tensor([1.]))
lhs_grad_2 = F.grad(target_feature_switch(g, lhs))
rhs_grad_2 = F.grad(target_feature_switch(g, rhs))
......@@ -192,7 +229,7 @@ def test_all_binary_builtins():
print("ERROR: Test {}_{}_{}_{} {}".
format(lhs, binary_op, rhs, reducer, broadcast))
print(a, b)
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
for i, (x, y) in enumerate(zip(F.asnumpy(F.cpu(a)).flatten(), F.asnumpy(F.cpu(b)).flatten())):
if not np.allclose(x, y, rtol, atol):
print('@{} {} v.s. {}'.format(i, x, y))
......@@ -221,14 +258,16 @@ def test_all_binary_builtins():
g.add_edge(18, 1)
g.add_edge(19, 0)
g.add_edge(19, 1)
nid = F.tensor([1, 3, 4, 5, 7, 10, 13, 17, 19])
target = ["u", "v", "e"]
for lhs, rhs in product(target, target):
if lhs == rhs:
continue
for binary_op in ["add", "sub", "mul", "div"]:
for reducer in ["sum", "max", "min", "prod"]:
for reducer in ["sum", "max", "min", "prod", "mean"]:
for broadcast in ["none", lhs, rhs]:
_test(g, lhs, rhs, binary_op, reducer)
for partial in [False, True]:
_test(g, lhs, rhs, binary_op, reducer, partial, nid)
if __name__ == '__main__':
test_copy_src_reduce()
......
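The tests now exercise both the full-graph path (update_all) and the partial path (pull), where only the given destination nodes are reduced onto and the degree count reuses the same out_map. A minimal sketch of the two calls with the new reducer (toy graph, for illustration only):

import dgl
import dgl.function as fn
import torch as th

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2], [3, 3, 0])
g.ndata['h'] = th.arange(4, dtype=th.float32).view(4, 1)

# full graph: every node with in-edges gets the mean of its messages
g.update_all(fn.copy_src(src='h', out='m'), fn.mean(msg='m', out='r_full'))

# partial: only the pulled nodes (here node 3) run the message/reduce pair
g.pull(th.tensor([3]), fn.copy_src(src='h', out='m'), fn.mean(msg='m', out='r_part'))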