Commit d533e9ba authored by zengxy's avatar zengxy Committed by Da Zheng
Browse files

[BugFix] Fix getting src and dst id of ALL edges in NodeFlow.apply_block (#515)

parent d8c69d53
......@@ -723,8 +723,8 @@ class NodeFlow(DGLBaseGraph):
block_id : int
The specified block to update edge embeddings.
func : callable or None, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
Apply function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
edges : a list of edge Ids or ALL.
The edges to run the edge update function.
inplace : bool, optional
......@@ -734,12 +734,10 @@ class NodeFlow(DGLBaseGraph):
func = self._apply_edge_funcs[block_id]
assert func is not None
def _layer_local_nid(layer_id):
return F.arange(0, self.layer_size(layer_id))
if is_all(edges):
u = utils.toindex(_layer_local_nid(block_id))
v = utils.toindex(_layer_local_nid(block_id + 1))
u, v, _ = self.block_edges(block_id)
u = utils.toindex(u)
v = utils.toindex(v)
eid = utils.toindex(slice(0, self.block_size(block_id)))
elif isinstance(edges, tuple):
u, v = edges
......
......@@ -115,14 +115,22 @@ def check_apply_edges(create_node_flow):
num_layers = 2
for i in range(num_layers):
g = generate_rand_graph(100)
g.ndata["f"] = F.randn((100, 10))
nf = create_node_flow(g, num_layers)
nf.copy_from_parent()
new_feats = F.randn((nf.block_size(i), 5))
def update_func(nodes):
return {'h2' : new_feats}
def update_func(edges):
return {'h2': new_feats, "f2": edges.src["f"] + edges.dst["f"]}
nf.apply_block(i, update_func)
assert F.array_equal(nf.blocks[i].data['h2'], new_feats)
eids = nf.block_parent_eid(i)
srcs, dsts = g.find_edges(eids)
expected_f_sum = g.ndata["f"][srcs] + g.ndata["f"][dsts]
assert F.array_equal(nf.blocks[i].data['f2'], expected_f_sum)
def test_apply_edges():
check_apply_edges(create_full_nodeflow)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment