Unverified Commit bacc9047 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] initialize data if null when converting from row sorted coo to csr (#3360)

parent 2647afc9
......@@ -323,7 +323,7 @@ template <class IdType> CSRMatrix SortedCOOToCSR(const COOMatrix &coo) {
Bp[0] = 0;
IdType *const fill_data =
data ? nullptr : static_cast<IdType *>(coo.data->data);
data ? nullptr : static_cast<IdType *>(ret_data->data);
if (NNZ > 0) {
auto num_threads = omp_get_max_threads();
......
......@@ -847,6 +847,16 @@ def test_to_simple(idtype):
assert 'h' not in sg.nodes['user'].data
assert 'hh' not in sg.nodes['user'].data
# verify DGLGraph.edge_ids() after dgl.to_simple()
# in case ids are not initialized in underlying coo2csr()
u = F.tensor([0, 1, 2])
v = F.tensor([1, 2, 3])
eids = F.tensor([0, 1, 2])
g = dgl.graph((u, v))
assert F.array_equal(g.edge_ids(u, v), eids)
sg = dgl.to_simple(g)
assert F.array_equal(sg.edge_ids(u, v), eids)
@parametrize_dtype
def test_to_block(idtype):
def check(g, bg, ntype, etype, dst_nodes, include_dst_in_src=True):
......
......@@ -148,6 +148,39 @@ bool isSparseCOO(const int64_t &num_threads, const int64_t &num_nodes,
// refer to COOToCSR<>() in ~dgl/src/array/cpu/spmat_op_impl_coo for details.
return num_threads * num_nodes > 4 * num_edges;
}
template <typename IDX>
aten::COOMatrix RowSorted_NullData_COO(DLContext ctx = CTX) {
// [[0, 1, 1, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
// row : [0, 0, 1, 2, 2]
// col : [1, 2, 0, 2, 3]
return aten::COOMatrix(4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 0, 1, 2, 2}),
sizeof(IDX) * 8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}),
sizeof(IDX) * 8, ctx),
aten::NullArray(), true, false);
}
template <typename IDX>
aten::CSRMatrix RowSorted_NullData_CSR(DLContext ctx = CTX) {
// [[0, 1, 1, 0, 0],
// [1, 0, 0, 0, 0],
// [0, 0, 1, 1, 0],
// [0, 0, 0, 0, 0]]
// data: [0, 1, 2, 3, 4]
return aten::CSRMatrix(4, 5,
aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}),
sizeof(IDX) * 8, ctx),
aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}),
sizeof(IDX) * 8, ctx),
aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 3, 4}),
sizeof(IDX) * 8, ctx),
false);
}
} // namespace
template <typename IDX>
......@@ -192,6 +225,20 @@ void _TestCOOToCSR(DLContext ctx) {
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
rs_coo = RowSorted_NullData_COO<IDX>(ctx);
ASSERT_TRUE(rs_coo.row_sorted);
rs_csr = RowSorted_NullData_CSR<IDX>(ctx);
rs_tcsr = aten::COOToCSR(rs_coo);
ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
ASSERT_EQ(rs_csr.num_rows, rs_tcsr.num_rows);
ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
ASSERT_EQ(rs_csr.num_cols, rs_tcsr.num_cols);
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indices, rs_tcsr.indices));
ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.data, rs_tcsr.data));
ASSERT_TRUE(ArrayEQ<IDX>(rs_coo.col, rs_tcsr.indices));
ASSERT_FALSE(ArrayEQ<IDX>(rs_coo.data, rs_tcsr.data));
// Convert from col sorted coo
coo = COO1<IDX>(ctx);
auto src_coo = aten::COOSort(coo, true);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment