"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5e6417e9887be8f02ab5b4f5c548dff7f3a4c8f6"
Unverified Commit fd658745 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving sparse_ops tests. (#6147)

parent 97697055
......@@ -43,7 +43,8 @@ def _random_simple_graph(
@parametrize_idtype
@pytest.mark.parametrize("dtype", [F.float32, F.float64])
def test_csrmm(idtype, dtype):
@pytest.mark.parametrize("return_edge_ids", [True, False])
def test_csrmm(idtype, dtype, return_edge_ids):
a, A = _random_simple_graph(
idtype, dtype, F.ctx(), 500, 600, 9000, "A", "B", "AB"
)
......@@ -53,7 +54,7 @@ def test_csrmm(idtype, dtype):
C, C_weights = dgl._sparse_ops._csrmm(
A._graph, A.edata["w"], B._graph, B.edata["w"], 2
)
C_adj = C.adjacency_matrix_scipy(0, False, "csr")
C_adj = C.adjacency_matrix_scipy(0, False, "csr", return_edge_ids)
C_adj.data = F.asnumpy(C_weights)
C_adj = F.tensor(C_adj.todense(), dtype=dtype)
c = F.tensor((a * b).todense(), dtype=dtype)
......@@ -111,7 +112,8 @@ def test_csrmm_backward(idtype, dtype, num_vtypes):
@parametrize_idtype
@pytest.mark.parametrize("dtype", [F.float32, F.float64])
def test_csrsum(idtype, dtype):
@pytest.mark.parametrize("return_edge_ids", [True, False])
def test_csrsum(idtype, dtype, return_edge_ids):
a, A = _random_simple_graph(
idtype, dtype, F.ctx(), 500, 600, 9000, "A", "B", "AB"
)
......@@ -121,7 +123,7 @@ def test_csrsum(idtype, dtype):
C, C_weights = dgl._sparse_ops._csrsum(
[A._graph, B._graph], [A.edata["w"], B.edata["w"]]
)
C_adj = C.adjacency_matrix_scipy(0, False, "csr")
C_adj = C.adjacency_matrix_scipy(0, False, "csr", return_edge_ids)
C_adj.data = F.asnumpy(C_weights)
C_adj = F.tensor(C_adj.todense(), dtype=dtype)
c = F.tensor((a + b).todense(), dtype=dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment