Unverified Commit 40caf1ab authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[HeteroGraph] Fix the failure of apply_nodes when the input function changes...

[HeteroGraph] Fix the failure of apply_nodes when the input function changes feature size for all nodes (#2223)

* Fix

* Fix
parent 7e0107c3
...@@ -4016,10 +4016,10 @@ class DGLHeteroGraph(object): ...@@ -4016,10 +4016,10 @@ class DGLHeteroGraph(object):
ntid = self.get_ntype_id(ntype) ntid = self.get_ntype_id(ntype)
ntype = self.ntypes[ntid] ntype = self.ntypes[ntid]
if is_all(v): if is_all(v):
v = self.nodes(ntype) v_id = self.nodes(ntype)
else: else:
v = utils.prepare_tensor(self, v, 'v') v_id = utils.prepare_tensor(self, v, 'v')
ndata = core.invoke_node_udf(self, v, ntype, func, orig_nid=v) ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id)
self._set_n_repr(ntid, v, ndata) self._set_n_repr(ntid, v, ndata)
def apply_edges(self, func, edges=ALL, etype=None, inplace=False): def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
......
...@@ -1295,6 +1295,8 @@ def test_subgraph(idtype): ...@@ -1295,6 +1295,8 @@ def test_subgraph(idtype):
def test_apply(idtype): def test_apply(idtype):
def node_udf(nodes): def node_udf(nodes):
return {'h': nodes.data['h'] * 2} return {'h': nodes.data['h'] * 2}
def node_udf2(nodes):
return {'h': F.sum(nodes.data['h'], dim=1, keepdims=True)}
def edge_udf(edges): def edge_udf(edges):
return {'h': edges.data['h'] * 2 + edges.src['h']} return {'h': edges.data['h'] * 2 + edges.src['h']}
...@@ -1314,6 +1316,11 @@ def test_apply(idtype): ...@@ -1314,6 +1316,11 @@ def test_apply(idtype):
g['plays'].apply_edges(edge_udf) g['plays'].apply_edges(edge_udf)
assert F.array_equal(g['plays'].edata['h'], F.ones((4, 5)) * 12) assert F.array_equal(g['plays'].edata['h'], F.ones((4, 5)) * 12)
# Test the case that feature size changes
g.nodes['user'].data['h'] = F.ones((3, 5))
g.apply_nodes(node_udf2, ntype='user')
assert F.array_equal(g.nodes['user'].data['h'], F.ones((3, 1)) * 5)
# test fail case # test fail case
# fail due to multiple types # fail due to multiple types
with pytest.raises(DGLError): with pytest.raises(DGLError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment