"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2261510bbcf43525d556238d6b0c24112f348a83"
Commit 2c234118 authored by Zihao Ye, committed by Minjie Wang

[Feature] Add builtin mean reducer (#787)

* upd

* upd

* upd

* upd

* upd

* passed test

* add note

* upd

* trigger

* slight change

* upd

* upd

* trigger

* fix

* simplify

* upd

* upd

* fudge

* upd

* trigger

* test partial

* upd

* trigger
parent 73b2668f
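For a quick sense of what this change adds: `mean` becomes a builtin reduce function alongside `sum`, `max`, `min` and `prod`, so messages can be averaged per destination node without a user-defined reducer. A minimal usage sketch based on the registration and tests in this diff (the toy graph and feature names are illustrative, not part of the change):

```python
import dgl
import dgl.function as fn
import torch as th

# Toy graph with edges 0->2 and 1->2, so node 2 has in-degree 2.
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [2, 2])
g.ndata['h'] = th.tensor([[1.], [3.], [5.]])

# Copy source features onto edges, then average the incoming messages.
# Internally the change below implements 'mean' as a sum reduce divided by in-degree.
g.update_all(fn.copy_src(src='h', out='m'), fn.mean(msg='m', out='h_mean'))
print(g.ndata['h_mean'])  # row 2 should hold (1 + 3) / 2 = 2.0
```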
@@ -107,6 +107,8 @@ Here is a cheatsheet of all the DGL builtins.
 |                         | ``sum``                                            |                       |
 |                         +----------------------------------------------------+-----------------------+
 |                         | ``prod``                                           |                       |
+|                         +----------------------------------------------------+-----------------------+
+|                         | ``mean``                                           |                       |
 +-------------------------+----------------------------------------------------+-----------------------+
 
 Next Step
......
@@ -9,6 +9,7 @@ import numbers
 import builtins
 from ... import ndarray as dglnd
 from ... import kernel as K
+from ...function.base import TargetCode
 
 MX_VERSION = LooseVersion(mx.__version__)
 # After MXNet 1.5, empty tensors aren't supported by default.
@@ -368,20 +369,48 @@ class BinaryReduce(mx.autograd.Function):
                             ctx=lhs_data.context, dtype=lhs_data.dtype)
         out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
         K.binary_op_reduce(
-            self.reducer, self.binary_op, self.graph, self.lhs, self.rhs,
+            self.reducer if self.reducer != 'mean' else 'sum',
+            self.binary_op, self.graph, self.lhs, self.rhs,
             lhs_data_nd, rhs_data_nd, out_data_nd, self.lhs_map[0],
             self.rhs_map[0], self.out_map[0])
+        # normalize if mean reducer
+        # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
+        if self.reducer == 'mean':
+            degs = nd.empty((out_data.shape[0],),
+                            ctx=out_data.context, dtype=out_data.dtype)
+            degs_nd = zerocopy_to_dgl_ndarray(degs)
+            if self.lhs != TargetCode.DST:
+                target = self.lhs
+                n = lhs_data.shape[0]
+                in_map = self.lhs_map[0]
+            else:
+                target = self.rhs
+                n = rhs_data.shape[0]
+                in_map = self.rhs_map[0]
+            in_ones = nd.ones((n,), ctx=lhs_data.context, dtype=lhs_data.dtype)
+            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
+            K.copy_reduce(
+                'sum', self.graph, target, in_ones_nd, degs_nd,
+                in_map, self.out_map[0])
+            # reshape
+            degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
+            out_data = out_data / degs
+        else:
+            degs = None
         self.save_for_backward(lhs_data_nd, rhs_data_nd, out_data_nd,
-                               feat_shape)
+                               feat_shape, degs)
         return out_data
 
     def backward(self, grad_out):
-        lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape = self.saved_tensors
+        lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape, degs = self.saved_tensors
+        if self.reducer == 'mean':
+            grad_out = grad_out / degs
         grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
         grad_lhs = nd.empty((lhs_data_nd.shape[0],) + feat_shape,
                             ctx=grad_out.context, dtype=grad_out.dtype)
         K.backward_lhs_binary_op_reduce(
-            self.reducer, self.binary_op, self.graph, self.lhs, self.rhs,
+            self.reducer if self.reducer != 'mean' else 'sum',
+            self.binary_op, self.graph, self.lhs, self.rhs,
             lhs_data_nd, rhs_data_nd, out_data_nd, grad_out_nd,
             zerocopy_to_dgl_ndarray_for_write(grad_lhs), self.lhs_map[1],
             self.rhs_map[1], self.out_map[1])
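Why dividing `grad_out` by the degrees and reusing the unchanged `sum` backward kernel is correct: for a destination node $i$ with in-degree $d_i$,

$$y_i = \frac{1}{d_i}\sum_{j \in \mathcal{N}(i)} x_j \qquad\Longrightarrow\qquad \frac{\partial L}{\partial x_j} = \sum_{i\,:\,j \in \mathcal{N}(i)} \frac{1}{d_i}\,\frac{\partial L}{\partial y_i},$$

which is exactly the sum-reducer backward applied to `grad_out / degs`; the degree tensor is treated as a constant, so it needs no gradient path of its own.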
@@ -389,7 +418,8 @@ class BinaryReduce(mx.autograd.Function):
         grad_rhs = nd.empty((rhs_data_nd.shape[0],) + feat_shape,
                             ctx=grad_out.context, dtype=grad_out.dtype)
         K.backward_rhs_binary_op_reduce(
-            self.reducer, self.binary_op, self.graph, self.lhs, self.rhs,
+            self.reducer if self.reducer != 'mean' else 'sum',
+            self.binary_op, self.graph, self.lhs, self.rhs,
             lhs_data_nd, rhs_data_nd, out_data_nd, grad_out_nd,
             zerocopy_to_dgl_ndarray_for_write(grad_rhs), self.lhs_map[1],
             self.rhs_map[1], self.out_map[1])
@@ -423,18 +453,39 @@ class CopyReduce(mx.autograd.Function):
         in_data_nd = zerocopy_to_dgl_ndarray(in_data)
         out_data_nd = zerocopy_to_dgl_ndarray_for_write(out_data)
         K.copy_reduce(
-            self.reducer, self.graph, self.target, in_data_nd, out_data_nd,
+            self.reducer if self.reducer != 'mean' else 'sum',
+            self.graph, self.target, in_data_nd, out_data_nd,
             self.in_map[0], self.out_map[0])
-        self.save_for_backward(in_data_nd, out_data_nd)
+        # normalize if mean reducer
+        # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
+        if self.reducer == 'mean':
+            in_ones = nd.ones((in_data.shape[0],),
+                              ctx=in_data.context, dtype=in_data.dtype)
+            degs = nd.empty((out_data.shape[0],),
+                            ctx=out_data.context, dtype=out_data.dtype)
+            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
+            degs_nd = zerocopy_to_dgl_ndarray(degs)
+            K.copy_reduce(
+                'sum', self.graph, self.target, in_ones_nd, degs_nd,
+                self.in_map[0], self.out_map[0])
+            # reshape
+            degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
+            out_data = out_data / degs
+        else:
+            degs = None
+        self.save_for_backward(in_data_nd, out_data_nd, degs)
         return out_data
 
     def backward(self, grad_out):
-        in_data_nd, out_data_nd = self.saved_tensors
-        grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
+        in_data_nd, out_data_nd, degs = self.saved_tensors
         grad_in = nd.empty(in_data_nd.shape, ctx=grad_out.context,
                            dtype=grad_out.dtype)
+        if self.reducer == 'mean':
+            grad_out = grad_out / degs
+        grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
         K.backward_copy_reduce(
-            self.reducer, self.graph, self.target, in_data_nd, out_data_nd,
+            self.reducer if self.reducer != 'mean' else 'sum',
+            self.graph, self.target, in_data_nd, out_data_nd,
             grad_out_nd, zerocopy_to_dgl_ndarray_for_write(grad_in),
             self.in_map[1], self.out_map[1])
         # clear saved tensors explicitly
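The normalization trick used in both `BinaryReduce` and `CopyReduce` above, restated as a self-contained NumPy sketch (the function name, shapes and arrays are assumptions for illustration, not the DGL kernel API):

```python
import numpy as np

def mean_reduce(messages, dst, num_nodes):
    """Average messages per destination node: a sum reduce followed by
    division by in-degree, mirroring the 'mean' hack above."""
    feat_shape = messages.shape[1:]
    out = np.zeros((num_nodes,) + feat_shape, dtype=messages.dtype)
    np.add.at(out, dst, messages)              # sum reduce into destination rows
    degs = np.zeros((num_nodes,), dtype=messages.dtype)
    np.add.at(degs, dst, 1.0)                  # in-degrees = sum reduce of a ones vector
    degs = degs.reshape((num_nodes,) + (1,) * len(feat_shape))
    degs = np.clip(degs, 1.0, None)            # clamp to 1 so isolated nodes don't divide by zero
    return out / degs

msgs = np.array([[1.0], [3.0]])                # two messages, both sent to node 2
print(mean_reduce(msgs, np.array([2, 2]), 3))  # row 2 == [2.0]; rows 0 and 1 stay zero
```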
......
@@ -8,6 +8,7 @@ from torch.utils import dlpack
 
 from ... import ndarray as nd
 from ... import kernel as K
+from ...function.base import TargetCode
 
 TH_VERSION = LooseVersion(th.__version__)
@@ -289,34 +290,61 @@ class BinaryReduce(th.autograd.Function):
         out_data = lhs_data.new_empty((out_size,) + feat_shape)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         K.binary_op_reduce(
-            reducer, binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
+            reducer if reducer != 'mean' else 'sum',
+            binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
             out_data_nd, lhs_map[0], rhs_map[0], out_map[0])
+        # normalize if mean reducer
+        # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
+        if reducer == 'mean':
+            degs = lhs_data.new_empty((out_data.shape[0],))
+            degs_nd = zerocopy_to_dgl_ndarray(degs)
+            if lhs != TargetCode.DST:  # src or edge
+                target = lhs
+                n = lhs_data.shape[0]
+                in_map = lhs_map[0]
+            else:  # rhs != TargetCode.DST
+                target = rhs
+                n = rhs_data.shape[0]
+                in_map = rhs_map[0]
+            in_ones = lhs_data.new_ones((n,))
+            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
+            K.copy_reduce(
+                'sum', graph, target, in_ones_nd, degs_nd, in_map, out_map[0])
+            # reshape
+            degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
+            out_data = out_data / degs
+        else:
+            degs = None
         # save_for_backward can only save variables
         ctx.backward_cache = (reducer, binary_op, graph, lhs, rhs, lhs_map,
                               rhs_map, out_map, lhs_data_nd, rhs_data_nd,
-                              out_data_nd, feat_shape)
+                              out_data_nd, feat_shape, degs)
         return out_data
 
     @staticmethod
     def backward(ctx, grad_out):
         reducer, binary_op, graph, lhs, rhs, lhs_map, rhs_map, out_map, \
-            lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape \
+            lhs_data_nd, rhs_data_nd, out_data_nd, feat_shape, degs \
             = ctx.backward_cache
         ctx.backward_cache = None
         grad_lhs = None
         grad_rhs = None
+        if reducer == 'mean':
+            grad_out = grad_out / degs
         grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
         if ctx.needs_input_grad[5]:
             grad_lhs = grad_out.new_empty((lhs_data_nd.shape[0],) + feat_shape)
             K.backward_lhs_binary_op_reduce(
-                reducer, binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
+                reducer if reducer != 'mean' else 'sum',
+                binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
                 out_data_nd, grad_out_nd, zerocopy_to_dgl_ndarray(grad_lhs),
                 lhs_map[1], rhs_map[1], out_map[1])
             grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape)
         if ctx.needs_input_grad[6]:
             grad_rhs = grad_out.new_empty((rhs_data_nd.shape[0],) + feat_shape)
             K.backward_rhs_binary_op_reduce(
-                reducer, binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
+                reducer if reducer != 'mean' else 'sum',
+                binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
                 out_data_nd, grad_out_nd, zerocopy_to_dgl_ndarray(grad_rhs),
                 lhs_map[1], rhs_map[1], out_map[1])
             grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)
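As the `# save_for_backward can only save variables` comment hints, the PyTorch path stashes non-tensor state (graph handle, target codes, maps, the reducer name) on `ctx.backward_cache`, clears it in `backward`, and consults `ctx.needs_input_grad` before computing each gradient. A stripped-down sketch of that same pattern (the `Scale` function and its arguments are hypothetical, not DGL code):

```python
import torch as th

class Scale(th.autograd.Function):
    """Multiply a tensor by a Python float, caching the non-tensor factor."""

    @staticmethod
    def forward(ctx, x, factor):
        # A plain Python float cannot go through save_for_backward, so stash it on ctx.
        ctx.backward_cache = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        factor = ctx.backward_cache
        ctx.backward_cache = None       # drop references eagerly, as the DGL code does
        grad_x = None
        if ctx.needs_input_grad[0]:     # only compute gradients that are actually requested
            grad_x = grad_out * factor
        return grad_x, None             # one return value per forward() argument

x = th.ones(3, requires_grad=True)
Scale.apply(x, 2.5).sum().backward()
print(x.grad)                           # tensor([2.5000, 2.5000, 2.5000])
```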
@@ -333,24 +361,41 @@ class CopyReduce(th.autograd.Function):
         in_data_nd = zerocopy_to_dgl_ndarray(in_data)
         out_data_nd = zerocopy_to_dgl_ndarray(out_data)
         K.copy_reduce(
-            reducer, graph, target, in_data_nd, out_data_nd, in_map[0],
-            out_map[0])
+            reducer if reducer != 'mean' else 'sum',
+            graph, target, in_data_nd, out_data_nd, in_map[0], out_map[0])
+        # normalize if mean reducer
+        # NOTE(zihao): this is a temporary hack and we should have better solution in the future.
+        if reducer == 'mean':
+            in_ones = in_data.new_ones((in_data.shape[0],))
+            degs = in_data.new_empty((out_data.shape[0],))
+            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
+            degs_nd = zerocopy_to_dgl_ndarray(degs)
+            K.copy_reduce(
+                'sum', graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0])
+            # reshape
+            degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.dim() - 1)).clamp(min=1)
+            out_data = out_data / degs
+        else:
+            degs = None
         # save_for_backward can only save variables
         ctx.backward_cache = (reducer, graph, target, in_map, out_map,
-                              in_data_nd, out_data_nd)
+                              in_data_nd, out_data_nd, degs)
         return out_data
 
     @staticmethod
     def backward(ctx, grad_out):
-        reducer, graph, target, in_map, out_map, in_data_nd, out_data_nd \
+        reducer, graph, target, in_map, out_map, in_data_nd, out_data_nd, degs \
            = ctx.backward_cache
        ctx.backward_cache = None
        grad_in = None
+        if reducer == 'mean':
+            grad_out = grad_out / degs
        grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
        if ctx.needs_input_grad[3]:
            grad_in = grad_out.new_empty(in_data_nd.shape)
            K.backward_copy_reduce(
-                reducer, graph, target, in_data_nd, out_data_nd, grad_out_nd,
+                reducer if reducer != 'mean' else 'sum',
+                graph, target, in_data_nd, out_data_nd, grad_out_nd,
                zerocopy_to_dgl_ndarray(grad_in), in_map[1], out_map[1])
        return None, None, None, grad_in, None, None, None
......
@@ -51,7 +51,7 @@ class SimpleReduceFunction(ReduceFunction):
 
 ###############################################################################
 # Generate all following reducer functions:
-# sum, max, min, prod
+# sum, max, min, mean, prod
 
 def _gen_reduce_builtin(reducer):
     docstring = """Builtin reduce function that aggregates messages by {0}.

@@ -87,7 +87,7 @@ __all__ = []
 
 def _register_builtin_reduce_func():
     """Register builtin reduce functions"""
-    for reduce_op in ["max", "min", "sum", "prod"]:
+    for reduce_op in ["max", "min", "sum", "mean", "prod"]:
         builtin = _gen_reduce_builtin(reduce_op)
         setattr(sys.modules[__name__], reduce_op, builtin)
         __all__.append(reduce_op)
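The reducer module builds one builtin per name and attaches it to the module, which is why `fn.mean` exists as soon as `"mean"` is added to the list above. A self-contained sketch of that registration pattern (the stub body is illustrative; the real `_gen_reduce_builtin` returns a `SimpleReduceFunction`):

```python
import sys

def _gen_reduce_builtin(reducer):
    """Generate a stub builtin reduce function for a given reducer name."""
    def func(msg, out):
        return {'reducer': reducer, 'msg': msg, 'out': out}
    func.__name__ = reducer
    func.__doc__ = "Builtin reduce function that aggregates messages by {}.".format(reducer)
    return func

__all__ = []

def _register_builtin_reduce_func():
    """Register builtin reduce functions as module-level attributes."""
    for reduce_op in ["max", "min", "sum", "mean", "prod"]:
        builtin = _gen_reduce_builtin(reduce_op)
        setattr(sys.modules[__name__], reduce_op, builtin)
        __all__.append(reduce_op)

_register_builtin_reduce_func()
print(mean('m', 'h_mean'))  # the generated `mean` is now an attribute of this module
```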
......
@@ -14,6 +14,8 @@ def udf_copy_src(edges):
 def udf_copy_edge(edges):
     return {'m': edges.data['e']}
 
+def udf_mean(nodes):
+    return {'r2': nodes.mailbox['m'].mean(1)}
+
 def udf_sum(nodes):
     return {'r2': nodes.mailbox['m'].sum(1)}

@@ -26,8 +28,8 @@ def udf_max(nodes):
 D1 = 5
 D2 = 3
 D3 = 4
-builtin = {'sum': fn.sum, 'max': fn.max}
-udf_reduce = {'sum': udf_sum, 'max': udf_max}
+builtin = {'sum': fn.sum, 'max': fn.max, 'mean': fn.mean}
+udf_reduce = {'sum': udf_sum, 'max': udf_max, 'mean': udf_mean}
 fill_value = {'sum': 0, 'max': float("-inf")}
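The tests pair each builtin with a UDF reference; for `mean` that reference is simply an average over the mailbox's message dimension. Condensed, the equivalence check looks roughly like this (graph size and feature names are illustrative, and zero-in-degree nodes fall back to zeros on both paths):

```python
import dgl
import dgl.function as fn
import networkx as nx
import torch as th

g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
g.ndata['u'] = th.randn(g.number_of_nodes(), 4)

# Builtin mean reducer.
g.update_all(fn.copy_src(src='u', out='m'), fn.mean(msg='m', out='r1'))

# UDF reference: average over the message dimension of the mailbox.
g.update_all(fn.copy_src(src='u', out='m'),
             lambda nodes: {'r2': nodes.mailbox['m'].mean(1)})

assert th.allclose(g.ndata['r1'], g.ndata['r2'], atol=1e-6)
```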
@@ -57,17 +59,23 @@ def generate_feature(g, broadcast='none'):
 
 def test_copy_src_reduce():
-    def _test(red):
+    def _test(red, partial):
         g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
         hu, hv, he = generate_feature(g, 'none')
+        if partial:
+            nid = F.tensor(list(range(0, 100, 2)))
         g.ndata['u'] = F.attach_grad(F.clone(hu))
         g.ndata['v'] = F.attach_grad(F.clone(hv))
         g.edata['e'] = F.attach_grad(F.clone(he))
         with F.record_grad():
-            g.update_all(fn.copy_src(src='u', out='m'),
-                         builtin[red](msg='m', out='r1'))
+            if partial:
+                g.pull(nid, fn.copy_src(src='u', out='m'),
+                       builtin[red](msg='m', out='r1'))
+            else:
+                g.update_all(fn.copy_src(src='u', out='m'),
+                             builtin[red](msg='m', out='r1'))
             r1 = g.ndata['r1']
             F.backward(r1.sum())
             n_grad1 = F.grad(g.ndata['u'])
@@ -78,7 +86,10 @@ def test_copy_src_reduce():
         g.edata['e'] = F.attach_grad(F.clone(he))
         with F.record_grad():
-            g.update_all(udf_copy_src, udf_reduce[red])
+            if partial:
+                g.pull(nid, udf_copy_src, udf_reduce[red])
+            else:
+                g.update_all(udf_copy_src, udf_reduce[red])
             r2 = g.ndata['r2']
             F.backward(r2.sum())
             n_grad2 = F.grad(g.ndata['u'])
@@ -86,21 +97,34 @@ def test_copy_src_reduce():
         assert F.allclose(r1, r2)
         assert(F.allclose(n_grad1, n_grad2))
 
-    _test('sum')
-    _test('max')
+    _test('sum', False)
+    _test('max', False)
+    _test('mean', False)
+    _test('sum', True)
+    _test('max', True)
+    _test('mean', True)
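The new `partial` flag exercises `g.pull`, which runs the same message/reduce pair but only into an explicit set of destination nodes, while `update_all` reduces into every node. Roughly (node ids and feature names are illustrative):

```python
import dgl
import dgl.function as fn
import torch as th

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2], [1, 2, 3])
g.ndata['u'] = th.arange(4, dtype=th.float32).unsqueeze(1)

# Full update: every destination node receives its averaged message.
g.update_all(fn.copy_src(src='u', out='m'), fn.mean(msg='m', out='r'))

# Partial update: only nodes 1 and 3 are reduced into; other rows of 'r' are untouched.
g.pull([1, 3], fn.copy_src(src='u', out='m'), fn.mean(msg='m', out='r'))
```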
 
 def test_copy_edge_reduce():
-    def _test(red):
+    def _test(red, partial):
         g = dgl.DGLGraph(nx.erdos_renyi_graph(100, 0.1))
         hu, hv, he = generate_feature(g, 'none')
+        if partial:
+            nid = F.tensor(list(range(0, 100, 2)))
         g.ndata['u'] = F.attach_grad(F.clone(hu))
         g.ndata['v'] = F.attach_grad(F.clone(hv))
         g.edata['e'] = F.attach_grad(F.clone(he))
         with F.record_grad():
-            g.update_all(fn.copy_edge(edge='e', out='m'),
-                         builtin[red](msg='m', out='r1'))
+            if partial:
+                g.pull(nid, fn.copy_edge(edge='e', out='m'),
+                       builtin[red](msg='m', out='r1'))
+            else:
+                g.update_all(fn.copy_edge(edge='e', out='m'),
+                             builtin[red](msg='m', out='r1'))
             r1 = g.ndata['r1']
             F.backward(r1.sum())
             e_grad1 = F.grad(g.edata['e'])
@@ -111,7 +135,10 @@ def test_copy_edge_reduce():
         g.edata['e'] = F.attach_grad(F.clone(he))
         with F.record_grad():
-            g.update_all(udf_copy_edge, udf_reduce[red])
+            if partial:
+                g.pull(nid, udf_copy_edge, udf_reduce[red])
+            else:
+                g.update_all(udf_copy_edge, udf_reduce[red])
             r2 = g.ndata['r2']
             F.backward(r2.sum())
             e_grad2 = F.grad(g.edata['e'])
@@ -119,12 +146,16 @@ def test_copy_edge_reduce():
         assert F.allclose(r1, r2)
         assert(F.allclose(e_grad1, e_grad2))
 
-    _test('sum')
-    _test('max')
+    _test('sum', False)
+    _test('max', False)
+    _test('mean', False)
+    _test('sum', True)
+    _test('max', True)
+    _test('mean', True)
 
 def test_all_binary_builtins():
-    def _test(g, lhs, rhs, binary_op, reducer, broadcast='none'):
+    def _test(g, lhs, rhs, binary_op, reducer, partial, nid, broadcast='none'):
         hu, hv, he = generate_feature(g, broadcast)
         g.ndata['u'] = F.attach_grad(F.clone(hu))
         g.ndata['v'] = F.attach_grad(F.clone(hv))
@@ -143,8 +174,11 @@ def test_all_binary_builtins():
             return g.edata["e"]
 
         with F.record_grad():
-            g.update_all(builtin_msg(lhs, rhs, 'm'), builtin_red('m', 'r1'))
-            r1 = g.ndata['r1']
+            if partial:
+                g.pull(nid, builtin_msg(lhs, rhs, 'm'), builtin_red('m', 'r1'))
+            else:
+                g.update_all(builtin_msg(lhs, rhs, 'm'), builtin_red('m', 'r1'))
+            r1 = g.ndata.pop('r1')
             F.backward(r1.sum())
             lhs_grad_1 = F.grad(target_feature_switch(g, lhs))
             rhs_grad_1 = F.grad(target_feature_switch(g, rhs))
@@ -175,8 +209,11 @@ def test_all_binary_builtins():
             return {"r2": op(nodes.mailbox['m'], 1)}
 
         with F.record_grad():
-            g.update_all(mfunc, rfunc)
-            r2 = g.ndata['r2']
+            if partial:
+                g.pull(nid, mfunc, rfunc)
+            else:
+                g.update_all(mfunc, rfunc)
+            r2 = g.ndata.pop('r2')
             F.backward(r2.sum(), F.tensor([1.]))
             lhs_grad_2 = F.grad(target_feature_switch(g, lhs))
             rhs_grad_2 = F.grad(target_feature_switch(g, rhs))
@@ -192,7 +229,7 @@ def test_all_binary_builtins():
             print("ERROR: Test {}_{}_{}_{} {}".
                   format(lhs, binary_op, rhs, reducer, broadcast))
             print(a, b)
-            for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
+            for i, (x, y) in enumerate(zip(F.asnumpy(F.cpu(a)).flatten(), F.asnumpy(F.cpu(b)).flatten())):
                 if not np.allclose(x, y, rtol, atol):
                     print('@{} {} v.s. {}'.format(i, x, y))
@@ -221,14 +258,16 @@
     g.add_edge(18, 1)
     g.add_edge(19, 0)
     g.add_edge(19, 1)
+    nid = F.tensor([1, 3, 4, 5, 7, 10, 13, 17, 19])
 
     target = ["u", "v", "e"]
     for lhs, rhs in product(target, target):
         if lhs == rhs:
             continue
         for binary_op in ["add", "sub", "mul", "div"]:
-            for reducer in ["sum", "max", "min", "prod"]:
+            for reducer in ["sum", "max", "min", "prod", "mean"]:
                 for broadcast in ["none", lhs, rhs]:
-                    _test(g, lhs, rhs, binary_op, reducer)
+                    for partial in [False, True]:
+                        _test(g, lhs, rhs, binary_op, reducer, partial, nid)
 
 if __name__ == '__main__':
     test_copy_src_reduce()
......