Unverified Commit 365d3617 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix #1036 (#1037)

* fix

* unit test
parent 287f387b
...@@ -2208,7 +2208,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2208,7 +2208,7 @@ class DGLGraph(DGLBaseGraph):
raise DGLError("Group_by should be either src or dst") raise DGLError("Group_by should be either src or dst")
if is_all(edges): if is_all(edges):
u, v, _ = self._graph.edges() u, v, _ = self._graph.edges('eid')
eid = utils.toindex(slice(0, self.number_of_edges())) eid = utils.toindex(slice(0, self.number_of_edges()))
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
...@@ -2270,7 +2270,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -2270,7 +2270,7 @@ class DGLGraph(DGLBaseGraph):
if is_all(edges): if is_all(edges):
eid = utils.toindex(slice(0, self.number_of_edges())) eid = utils.toindex(slice(0, self.number_of_edges()))
u, v, _ = self._graph.edges() u, v, _ = self._graph.edges('eid')
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
......
...@@ -2383,7 +2383,7 @@ class DGLHeteroGraph(object): ...@@ -2383,7 +2383,7 @@ class DGLHeteroGraph(object):
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid) stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges): if is_all(edges):
u, v, _ = self._graph.edges(etid) u, v, _ = self._graph.edges(etid, 'eid')
eid = utils.toindex(slice(0, self.number_of_edges(etype))) eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
...@@ -2468,7 +2468,7 @@ class DGLHeteroGraph(object): ...@@ -2468,7 +2468,7 @@ class DGLHeteroGraph(object):
if is_all(edges): if is_all(edges):
eid = utils.toindex(slice(0, self._graph.number_of_edges(etid))) eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
u, v, _ = self._graph.edges(etid) u, v, _ = self._graph.edges(etid, 'eid')
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
......
import backend as F import backend as F
import dgl import dgl
import numpy as np
import scipy.sparse as ssp
import networkx as nx import networkx as nx
from dgl import DGLGraph from dgl import DGLGraph
from collections import defaultdict as ddict from collections import defaultdict as ddict
...@@ -654,6 +656,28 @@ def test_group_apply_edges(): ...@@ -654,6 +656,28 @@ def test_group_apply_edges():
# test group by destination nodes # test group by destination nodes
_test('dst') _test('dst')
# GitHub issue #1036
def test_group_apply_edges2():
m = ssp.random(10, 10, 0.2)
g = DGLGraph(m, readonly=True)
g.ndata['deg'] = g.in_degrees()
g.ndata['id'] = F.arange(0, g.number_of_nodes())
g.edata['id'] = F.arange(0, g.number_of_edges())
def apply(edges):
w = edges.data['id']
n_nodes, deg = w.shape
dst = edges.dst['id'][:, 0]
eid1 = F.asnumpy(g.in_edges(dst, 'eid')).reshape(n_nodes, deg).sort(1)
eid2 = F.asnumpy(edges.data['id']).sort(1)
assert np.array_equal(eid1, eid2)
return {'id2': w}
g.group_apply_edges('dst', apply, inplace=True)
def test_local_var(): def test_local_var():
g = DGLGraph(nx.path_graph(5)) g = DGLGraph(nx.path_graph(5))
g.ndata['h'] = F.zeros((g.number_of_nodes(), 3)) g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
...@@ -803,5 +827,6 @@ if __name__ == '__main__': ...@@ -803,5 +827,6 @@ if __name__ == '__main__':
test_dynamic_addition() test_dynamic_addition()
test_repr() test_repr()
test_group_apply_edges() test_group_apply_edges()
test_group_apply_edges2()
test_local_var() test_local_var()
test_local_scope() test_local_scope()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment