Unverified Commit ce27ebbb authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGIFX] fix a bug in edge_ids (#560)

* add test.

* fix compute.

* fix test.

* turn on test.

* fix a bug.

* add test.

* fix.

* disable test.
parent 924efc65
...@@ -606,7 +606,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -606,7 +606,8 @@ class ImmutableGraph: public GraphInterface {
*/ */
EdgeArray EdgeIds(IdArray src, IdArray dst) const override { EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
if (in_csr_) { if (in_csr_) {
return in_csr_->EdgeIds(dst, src); EdgeArray edges = in_csr_->EdgeIds(dst, src);
return EdgeArray{edges.dst, edges.src, edges.id};
} else { } else {
return GetOutCSR()->EdgeIds(src, dst); return GetOutCSR()->EdgeIds(src, dst);
} }
......
...@@ -3,6 +3,7 @@ import backend as F ...@@ -3,6 +3,7 @@ import backend as F
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import scipy as sp import scipy as sp
from scipy import sparse as spsp
import dgl import dgl
from dgl.graph_index import map_to_subgraph_nid, GraphIndex, create_graph_index from dgl.graph_index import map_to_subgraph_nid, GraphIndex, create_graph_index
from dgl import utils from dgl import utils
...@@ -182,8 +183,21 @@ def test_load_csr(): ...@@ -182,8 +183,21 @@ def test_load_csr():
assert np.all(F.asnumpy(src) == coo.row) assert np.all(F.asnumpy(src) == coo.row)
assert np.all(F.asnumpy(dst) == coo.col) assert np.all(F.asnumpy(dst) == coo.col)
def test_edge_ids():
np.random.seed(0)
csr = (spsp.random(20, 20, density=0.1, format='csr') != 0).astype(np.int64)
#csr = csr.transpose()
g = dgl.DGLGraph(csr, readonly=True)
num_nodes = g.number_of_nodes()
in_edges = g._graph.in_edges(v=dgl.utils.toindex([2]))
src, dst, eids = g._graph.edge_ids(dgl.utils.toindex(in_edges[0]),
dgl.utils.toindex(in_edges[1]))
assert np.all(in_edges[0].tonumpy() == src.tonumpy())
assert np.all(in_edges[1].tonumpy() == dst.tonumpy())
if __name__ == '__main__': if __name__ == '__main__':
test_basics() test_basics()
test_edge_ids()
test_graph_gen() test_graph_gen()
test_node_subgraph() test_node_subgraph()
test_create_graph() test_create_graph()
......
...@@ -107,16 +107,13 @@ def check_compute_func(worker_id, graph_name): ...@@ -107,16 +107,13 @@ def check_compute_func(worker_id, graph_name):
g.init_ndata('tmp', (g.number_of_nodes(), 10), 'float32') g.init_ndata('tmp', (g.number_of_nodes(), 10), 'float32')
data = g.nodes[:].data['tmp'] data = g.nodes[:].data['tmp']
# Test pull # Test pull
assert np.all(data[1].asnumpy() != g.nodes[1].data['preprocess'].asnumpy())
g.pull(1, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp')) g.pull(1, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
assert np.all(data[1].asnumpy() == g.nodes[1].data['preprocess'].asnumpy()) assert np.all(data[1].asnumpy() == g.nodes[1].data['preprocess'].asnumpy())
# Test send_and_recv # Test send_and_recv
# TODO(zhengda) it seems the test fails because send_and_recv has a bug in_edges = g.in_edges(v=2)
#in_edges = g.in_edges(v=2) g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
#assert np.all(data[2].asnumpy() != g.nodes[2].data['preprocess'].asnumpy()) assert np.all(data[2].asnumpy() == g.nodes[2].data['preprocess'].asnumpy())
#g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
#assert np.all(data[2].asnumpy() == g.nodes[2].data['preprocess'].asnumpy())
g.destroy() g.destroy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment