Commit 5e75f5db authored by Gan Quan's avatar Gan Quan Committed by Minjie Wang
Browse files

Fix reducing with 0 incoming messages (#48)

* fixing 0 degree reduces

* adopting accum=None semantics

* minor fix as per reviews

* TODO in test

* fix test
parent c42eac71
......@@ -541,7 +541,7 @@ class DGLGraph(DiGraph):
def _batch_recv(self, v, reduce_func, update_func):
f_update = update_func
reordered_v, all_reduced_msgs = self._batch_reduce(v, reduce_func)
null_v, reordered_v, all_reduced_msgs = self._batch_reduce(v, reduce_func)
if all_reduced_msgs is None:
# no message; only do recv.
if is_all(v):
......@@ -550,11 +550,23 @@ class DGLGraph(DiGraph):
self.set_n_repr(f_update(self.get_n_repr(v), None), v)
else:
# Read the node states in the degree-bucketing order.
if len(null_v) == 0:
null_ns = new_null_ns = None
else:
null_ns = self.get_n_repr(null_v)
new_null_ns = f_update(null_ns, None)
if len(reordered_v) == 0:
reordered_ns = new_reordered_ns = None
else:
reordered_ns = self.get_n_repr(reordered_v)
new_ns = f_update(reordered_ns, all_reduced_msgs)
new_reordered_ns = f_update(reordered_ns, all_reduced_msgs)
v_tensor = utils.pack2(null_v.totensor(), reordered_v.totensor())
new_ns = utils.pack2(new_null_ns, new_reordered_ns)
if is_all(v):
# First do reorder and then replace the whole column.
_, indices = F.sort(reordered_v.totensor())
_, indices = F.sort(v_tensor)
indices = utils.toindex(indices)
# TODO(minjie): following code should be included in Frame somehow.
if isinstance(new_ns, dict):
......@@ -566,12 +578,12 @@ class DGLGraph(DiGraph):
self._node_frame[__REPR__] = F.gather_row(new_ns, idx)
else:
# Use setter to do reorder.
self.set_n_repr(new_ns, reordered_v)
self.set_n_repr(new_ns, v_tensor)
def _batch_reduce(self, v, reduce_func):
if is_all(v) and len(self._msg_frame) == 0:
# no message has been sent
return None, None
return None, None, None
if is_all(v):
v = list(range(self.number_of_nodes()))
......@@ -585,11 +597,18 @@ class DGLGraph(DiGraph):
# degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
null_v_bucket = None
non_null_v_buckets = []
reduced_msgs = []
for deg, v_bkt in zip(degrees, v_buckets):
bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt)
if deg == 0:
assert null_v_bucket is None
null_v_bucket = v_bkt
continue
bkt_len = len(v_bkt)
uu, vv = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
in_msgs = self._msg_frame.select_rows(in_msg_ids)
......@@ -603,19 +622,24 @@ class DGLGraph(DiGraph):
else:
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
dst_reprs = self.get_n_repr(v_bkt)
non_null_v_buckets.append(v_bkt)
reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs))
# FIXME: this will only trigger if reduced_msgs is empty. Remove?
if len(reduced_msgs) == 0:
# no message has been sent to the specified node
return None, None
return None, None, None
# TODO: clear partial messages
self.clear_messages()
# Read the node states in the degree-bucketing order.
reordered_v = utils.toindex(F.pack(
[v_bkt.totensor() for v_bkt in v_buckets]))
null_v = utils.toindex(null_v_bucket or [])
reordered_v = utils.toindex(
F.pack([v_bkt.totensor() for v_bkt in non_null_v_buckets])
if len(non_null_v_buckets) > 0 else []
)
# Pack all reduced msgs together
if isinstance(reduced_msgs[0], dict):
keys = reduced_msgs[0].keys()
......@@ -625,7 +649,7 @@ class DGLGraph(DiGraph):
else:
all_reduced_msgs = F.pack(reduced_msgs)
return reordered_v, all_reduced_msgs
return null_v, reordered_v, all_reduced_msgs
def update_by_edge(self,
u, v,
......
......@@ -289,3 +289,14 @@ def cached_member(func):
else:
return func(self)
return wrapper
def pack2(a, b):
if a is None:
return b
elif b is None:
return a
else:
if isinstance(a, dict):
return {k: F.pack([a[k], b[k]]) for k in a}
else:
return F.pack([a, b])
......@@ -222,6 +222,34 @@ def test_update_routines():
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear()
def test_reduce_0deg():
g = DGLGraph()
g.add_nodes_from([0, 1, 2, 3, 4])
g.add_edge(1, 0)
g.add_edge(2, 0)
g.add_edge(3, 0)
g.add_edge(4, 0)
def _message(src, edge):
return src
def _reduce(node, msgs):
assert msgs is not None
return msgs.sum(1)
def _update(node, accum):
if node.shape[0] == 4:
assert accum is None
return node
else:
assert accum is not None
return node + accum
old_repr = th.randn(5, 5)
g.set_n_repr(old_repr)
g.update_all(_message, _reduce, _update, True)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0))
def _test_delete():
g = generate_graph()
ecol = Variable(th.randn(17, D), requires_grad=grad)
......@@ -239,4 +267,5 @@ if __name__ == '__main__':
test_batch_recv1()
test_batch_recv2()
test_update_routines()
test_reduce_0deg()
#test_delete()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment