Unverified Commit ce27ebbb authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGIFX] fix a bug in edge_ids (#560)

* add test.

* fix compute.

* fix test.

* turn on test.

* fix a bug.

* add test.

* fix.

* disable test.
parent 924efc65
......@@ -606,7 +606,8 @@ class ImmutableGraph: public GraphInterface {
*/
EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
if (in_csr_) {
return in_csr_->EdgeIds(dst, src);
EdgeArray edges = in_csr_->EdgeIds(dst, src);
return EdgeArray{edges.dst, edges.src, edges.id};
} else {
return GetOutCSR()->EdgeIds(src, dst);
}
......
......@@ -3,6 +3,7 @@ import backend as F
import networkx as nx
import numpy as np
import scipy as sp
from scipy import sparse as spsp
import dgl
from dgl.graph_index import map_to_subgraph_nid, GraphIndex, create_graph_index
from dgl import utils
......@@ -182,8 +183,21 @@ def test_load_csr():
assert np.all(F.asnumpy(src) == coo.row)
assert np.all(F.asnumpy(dst) == coo.col)
def test_edge_ids():
np.random.seed(0)
csr = (spsp.random(20, 20, density=0.1, format='csr') != 0).astype(np.int64)
#csr = csr.transpose()
g = dgl.DGLGraph(csr, readonly=True)
num_nodes = g.number_of_nodes()
in_edges = g._graph.in_edges(v=dgl.utils.toindex([2]))
src, dst, eids = g._graph.edge_ids(dgl.utils.toindex(in_edges[0]),
dgl.utils.toindex(in_edges[1]))
assert np.all(in_edges[0].tonumpy() == src.tonumpy())
assert np.all(in_edges[1].tonumpy() == dst.tonumpy())
if __name__ == '__main__':
test_basics()
test_edge_ids()
test_graph_gen()
test_node_subgraph()
test_create_graph()
......
......@@ -107,16 +107,13 @@ def check_compute_func(worker_id, graph_name):
g.init_ndata('tmp', (g.number_of_nodes(), 10), 'float32')
data = g.nodes[:].data['tmp']
# Test pull
assert np.all(data[1].asnumpy() != g.nodes[1].data['preprocess'].asnumpy())
g.pull(1, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
assert np.all(data[1].asnumpy() == g.nodes[1].data['preprocess'].asnumpy())
# Test send_and_recv
# TODO(zhengda) it seems the test fails because send_and_recv has a bug
#in_edges = g.in_edges(v=2)
#assert np.all(data[2].asnumpy() != g.nodes[2].data['preprocess'].asnumpy())
#g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
#assert np.all(data[2].asnumpy() == g.nodes[2].data['preprocess'].asnumpy())
in_edges = g.in_edges(v=2)
g.send_and_recv(in_edges, fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='tmp'))
assert np.all(data[2].asnumpy() == g.nodes[2].data['preprocess'].asnumpy())
g.destroy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment