Unverified Commit 365d3617 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix #1036 (#1037)

* fix

* unit test
parent 287f387b
......@@ -2208,7 +2208,7 @@ class DGLGraph(DGLBaseGraph):
raise DGLError("Group_by should be either src or dst")
if is_all(edges):
u, v, _ = self._graph.edges()
u, v, _ = self._graph.edges('eid')
eid = utils.toindex(slice(0, self.number_of_edges()))
elif isinstance(edges, tuple):
u, v = edges
......@@ -2270,7 +2270,7 @@ class DGLGraph(DGLBaseGraph):
if is_all(edges):
eid = utils.toindex(slice(0, self.number_of_edges()))
u, v, _ = self._graph.edges()
u, v, _ = self._graph.edges('eid')
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
......
......@@ -2383,7 +2383,7 @@ class DGLHeteroGraph(object):
etid = self.get_etype_id(etype)
stid, dtid = self._graph.metagraph.find_edge(etid)
if is_all(edges):
u, v, _ = self._graph.edges(etid)
u, v, _ = self._graph.edges(etid, 'eid')
eid = utils.toindex(slice(0, self.number_of_edges(etype)))
elif isinstance(edges, tuple):
u, v = edges
......@@ -2468,7 +2468,7 @@ class DGLHeteroGraph(object):
if is_all(edges):
eid = utils.toindex(slice(0, self._graph.number_of_edges(etid)))
u, v, _ = self._graph.edges(etid)
u, v, _ = self._graph.edges(etid, 'eid')
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
......
import backend as F
import dgl
import numpy as np
import scipy.sparse as ssp
import networkx as nx
from dgl import DGLGraph
from collections import defaultdict as ddict
......@@ -654,6 +656,28 @@ def test_group_apply_edges():
# test group by destination nodes
_test('dst')
# GitHub issue #1036
def test_group_apply_edges2():
m = ssp.random(10, 10, 0.2)
g = DGLGraph(m, readonly=True)
g.ndata['deg'] = g.in_degrees()
g.ndata['id'] = F.arange(0, g.number_of_nodes())
g.edata['id'] = F.arange(0, g.number_of_edges())
def apply(edges):
w = edges.data['id']
n_nodes, deg = w.shape
dst = edges.dst['id'][:, 0]
eid1 = F.asnumpy(g.in_edges(dst, 'eid')).reshape(n_nodes, deg).sort(1)
eid2 = F.asnumpy(edges.data['id']).sort(1)
assert np.array_equal(eid1, eid2)
return {'id2': w}
g.group_apply_edges('dst', apply, inplace=True)
def test_local_var():
g = DGLGraph(nx.path_graph(5))
g.ndata['h'] = F.zeros((g.number_of_nodes(), 3))
......@@ -803,5 +827,6 @@ if __name__ == '__main__':
test_dynamic_addition()
test_repr()
test_group_apply_edges()
test_group_apply_edges2()
test_local_var()
test_local_scope()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment