Unverified Commit 27cad329 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Kernel] Matrix Union (#1752)



* Matrix union

* Pass test

* Fix lint

* return map for unionCOO/unionCSR

* Revert "return map for unionCOO/unionCSR"

This reverts commit 28e96c40f0659f02b33d88bcf528af6a6267726d.

* Update

* lint

* lint

* Fix doc
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 29e6c93f
...@@ -373,7 +373,39 @@ COOMatrix COORowWiseTopk( ...@@ -373,7 +373,39 @@ COOMatrix COORowWiseTopk(
bool ascending = false); bool ascending = false);
/*! /*!
* \brief Union a list COOMatrix into one COOMatrix. * \brief Union two COOMatrix into one COOMatrix.
*
* All input matrices must have the same shape (num_rows and num_cols).
*
* Example:
*
* A = [[0, 0, 1, 0],
* [1, 0, 1, 1],
* [0, 1, 0, 0]]
*
* B = [[0, 1, 1, 0],
* [0, 0, 0, 1],
* [0, 0, 1, 0]]
*
* COOMatrix_A.num_rows : 3
* COOMatrix_A.num_cols : 4
* COOMatrix_B.num_rows : 3
* COOMatrix_B.num_cols : 4
*
* C = UnionCoo({A, B});
*
* C = [[0, 1, 2, 0],
* [1, 0, 1, 2],
* [0, 1, 1, 0]]
*
* COOMatrix_C.num_rows : 3
* COOMatrix_C.num_cols : 4
*/
COOMatrix UnionCoo(
const std::vector<COOMatrix>& coos);
/*!
* \brief DisjointUnion a list COOMatrix into one COOMatrix.
* *
* Examples: * Examples:
* *
......
...@@ -367,6 +367,38 @@ COOMatrix CSRRowWiseTopk( ...@@ -367,6 +367,38 @@ COOMatrix CSRRowWiseTopk(
FloatArray weight, FloatArray weight,
bool ascending = false); bool ascending = false);
/*!
 * \brief Union a list of CSRMatrix into one CSRMatrix.
 *
 * At least two matrices must be given and all of them must have the same
 * shape (num_rows and num_cols). Entries present in several inputs are all
 * kept, so the dense view of the result is the element-wise sum of the
 * dense views of the inputs.
 *
 * Example:
 *
 * A = [[0, 0, 1, 0],
 *      [1, 0, 1, 1],
 *      [0, 1, 0, 0]]
 *
 * B = [[0, 1, 1, 0],
 *      [0, 0, 0, 1],
 *      [0, 0, 1, 0]]
 *
 * CSRMatrix_A.num_rows : 3
 * CSRMatrix_A.num_cols : 4
 * CSRMatrix_B.num_rows : 3
 * CSRMatrix_B.num_cols : 4
 *
 * C = UnionCsr({A, B});
 *
 * C = [[0, 1, 2, 0],
 *      [1, 0, 1, 2],
 *      [0, 1, 1, 0]]
 *
 * CSRMatrix_C.num_rows : 3
 * CSRMatrix_C.num_cols : 4
 *
 * \param csrs The CSR matrices to union.
 * \return The union CSRMatrix, with the same shape as the inputs.
 */
CSRMatrix UnionCsr(
const std::vector<CSRMatrix>& csrs);
/*! /*!
* \brief Union a list CSRMatrix into one CSRMatrix. * \brief Union a list CSRMatrix into one CSRMatrix.
* *
......
...@@ -465,14 +465,6 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArr ...@@ -465,14 +465,6 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArr
return ret; return ret;
} }
/*!
 * \brief Reorder a COOMatrix with the given new row and column ids.
 *
 * Thin dispatcher: ATEN_COO_SWITCH binds XPU and IdType for `coo`, and the
 * actual work is done by impl::COOReorder<XPU, IdType>.
 *
 * \param coo The matrix to reorder.
 * \param new_row_ids New row ids; the exact mapping convention is defined by
 *        impl::COOReorder (not visible here -- TODO confirm).
 * \param new_col_ids New column ids; see the note on new_row_ids.
 * \return The matrix returned by impl::COOReorder.
 */
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COOReorder", {
ret = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);
});
return ret;
}
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) { CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
CSRMatrix ret; CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRRemove", { ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRRemove", {
...@@ -509,6 +501,27 @@ COOMatrix CSRRowWiseTopk( ...@@ -509,6 +501,27 @@ COOMatrix CSRRowWiseTopk(
return ret; return ret;
} }
/*!
 * \brief Union several CSR matrices of identical shape into one CSRMatrix.
 *
 * Validates the inputs and then forwards to impl::UnionCsr<XPU, IdType>,
 * with XPU and IdType bound by ATEN_CSR_SWITCH from csrs[0].
 *
 * \param csrs The matrices to union; at least two are required.
 * \return The CSRMatrix produced by impl::UnionCsr.
 */
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
CSRMatrix ret;
// a union needs at least two operands
CHECK_GT(csrs.size(), 1) << "UnionCsr creates a union of multiple CSRMatrixes";
// sanity check: every matrix must agree with csrs[0] on shape, and its
// indptr array must share the context and dtype of csrs[0].indptr
for (size_t i = 1; i < csrs.size(); ++i) {
CHECK_EQ(csrs[0].num_rows, csrs[i].num_rows) <<
"UnionCsr requires both CSRMatrix have same number of rows";
CHECK_EQ(csrs[0].num_cols, csrs[i].num_cols) <<
"UnionCsr requires both CSRMatrix have same number of cols";
CHECK_SAME_CONTEXT(csrs[0].indptr, csrs[i].indptr);
CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);
}
// csrs[0] alone drives the dispatch; the checks above make it representative
ATEN_CSR_SWITCH(csrs[0], XPU, IdType, "UnionCsr", {
ret = impl::UnionCsr<XPU, IdType>(csrs);
});
return ret;
}
std::tuple<CSRMatrix, IdArray, IdArray> std::tuple<CSRMatrix, IdArray, IdArray>
CSRToSimple(const CSRMatrix& csr) { CSRToSimple(const CSRMatrix& csr) {
std::tuple<CSRMatrix, IdArray, IdArray> ret; std::tuple<CSRMatrix, IdArray, IdArray> ret;
...@@ -645,6 +658,14 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) { ...@@ -645,6 +658,14 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return ret; return ret;
} }
/*!
 * \brief Reorder a COOMatrix with the given new row and column ids.
 *
 * Thin dispatcher: ATEN_COO_SWITCH binds XPU and IdType for `coo`, and the
 * actual work is done by impl::COOReorder<XPU, IdType>.
 *
 * \param coo The matrix to reorder.
 * \param new_row_ids New row ids; the exact mapping convention is defined by
 *        impl::COOReorder (not visible here -- TODO confirm).
 * \param new_col_ids New column ids; see the note on new_row_ids.
 * \return The matrix returned by impl::COOReorder.
 */
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COOReorder", {
ret = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);
});
return ret;
}
COOMatrix COORemove(COOMatrix coo, IdArray entries) { COOMatrix COORemove(COOMatrix coo, IdArray entries) {
COOMatrix ret; COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COORemove", { ATEN_COO_SWITCH(coo, XPU, IdType, "COORemove", {
...@@ -689,6 +710,68 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) { ...@@ -689,6 +710,68 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
return ret; return ret;
} }
/*!
 * \brief Union several COO matrices of identical shape into one COOMatrix.
 *
 * The row and col arrays of the inputs are concatenated in input order.
 * If at least one input carries a data array, a data array is built for the
 * result as well: the ids of coos[0] are kept as they are, and the ids of
 * each later matrix are shifted by the number of entries of all matrices
 * before it. Inputs without data contribute the consecutive range of ids
 * starting at that same offset.
 *
 * \param coos The matrices to union; at least two are required.
 * \return A COOMatrix flagged as neither row-sorted nor column-sorted.
 */
COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
  CHECK_GT(coos.size(), 1) << "UnionCoo creates a union of multiple COOMatrixes";
  // every input has to agree with coos[0] on shape, context and id dtype
  for (size_t i = 1; i < coos.size(); ++i) {
    CHECK_EQ(coos[0].num_rows, coos[i].num_rows) <<
      "UnionCoo requires both COOMatrix have same number of rows";
    CHECK_EQ(coos[0].num_cols, coos[i].num_cols) <<
      "UnionCoo requires both COOMatrix have same number of cols";
    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);
    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);
  }

  // the number of inputs is expected to be small, so plain vectors suffice
  std::vector<IdArray> row_parts;
  std::vector<IdArray> col_parts;
  bool any_has_data = false;
  for (const COOMatrix& coo : coos) {
    row_parts.push_back(coo.row);
    col_parts.push_back(coo.col);
    any_has_data |= COOHasData(coo);
  }

  IdArray row = Concat(row_parts);
  IdArray col = Concat(col_parts);
  IdArray data = NullArray();

  if (any_has_data) {
    std::vector<IdArray> eid_parts;
    int64_t offset = 0;  // number of entries contributed by earlier inputs
    for (size_t i = 0; i < coos.size(); ++i) {
      const COOMatrix& coo = coos[i];
      const int64_t nnz = coo.row->shape[0];
      if (!COOHasData(coo)) {
        // implicit ids: the next nnz consecutive ids starting at offset
        eid_parts.push_back(
          Range(offset, offset + nnz, coo.row->dtype.bits, coo.row->ctx));
      } else if (i == 0) {
        // ids of the first matrix are used unchanged
        eid_parts.push_back(coo.data);
      } else {
        eid_parts.push_back(coo.data + offset);
      }
      offset += nnz;
    }
    data = Concat(eid_parts);
  }

  return COOMatrix(
    coos[0].num_rows, coos[0].num_cols, row, col, data, false, false);
}
std::tuple<COOMatrix, IdArray, IdArray> std::tuple<COOMatrix, IdArray, IdArray>
COOToSimple(const COOMatrix& coo) { COOToSimple(const COOMatrix& coo) {
// coo column sorted // coo column sorted
......
...@@ -152,6 +152,10 @@ template <DLDeviceType XPU, typename IdType, typename DType> ...@@ -152,6 +152,10 @@ template <DLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseTopk( COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending); CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
// Device-specific kernel behind aten::UnionCsr: unions a list of CSR
// matrices into one. XPU selects the device and IdType the integer id type.
template <DLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr); std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
...@@ -224,6 +228,8 @@ template <DLDeviceType XPU, typename IdType, typename FloatType> ...@@ -224,6 +228,8 @@ template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseTopk( COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending); COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
///////////////////////// Graph Traverse routines //////////////////////////
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source); Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/coo_sort.cc
* \brief CSR matrix union (CPU implementation)
*/
#include <dgl/array.h>
#include <numeric>
#include <algorithm>
#include <vector>
#include <iterator>
namespace dgl {
namespace aten {
namespace impl {
/*!
 * \brief CPU kernel that unions CSR matrices of identical shape.
 *
 * Row i of the result holds the entries of row i of every input. Edge ids
 * are made unique across inputs: ids of csrs[0] are unchanged, ids of
 * csrs[k] are shifted by the total number of entries of csrs[0..k-1].
 * Inputs without a data array get the consecutive id range starting at
 * that offset.
 *
 * If every input is flagged as sorted, each result row is produced by a
 * k-way merge of the input rows, so the result is sorted too; on equal
 * column ids the entry of the earlier input comes first. Otherwise the
 * input rows are simply concatenated in input order and the result is
 * flagged as unsorted.
 *
 * The caller (aten::UnionCsr) is expected to have validated that the list
 * is non-empty and that all inputs share shape, context and id dtype.
 *
 * \param csrs The matrices to union.
 * \return The union CSRMatrix.
 */
template <DLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
  const size_t num_mats = csrs.size();
  const int64_t num_rows = csrs[0].num_rows;

  // some preprocess
  // we assume the number of csrs is not large in common cases
  std::vector<IdArray> data;  // owns the shifted / generated edge-id arrays
  std::vector<IdType *> data_data;
  std::vector<IdType *> indptr_data;
  std::vector<IdType *> indices_data;
  data.reserve(num_mats);
  data_data.reserve(num_mats);
  indptr_data.reserve(num_mats);
  indices_data.reserve(num_mats);

  int64_t num_edges = 0;
  bool sorted = true;
  for (size_t i = 0; i < num_mats; ++i) {
    // eids of csrs[0] remains unchanged
    // eids of csrs[1] will be increased by number of edges of csrs[0], etc.
    data.push_back(CSRHasData(csrs[i]) ?
                   csrs[i].data + num_edges:
                   Range(num_edges,
                         num_edges + csrs[i].indices->shape[0],
                         csrs[i].indptr->dtype.bits,
                         csrs[i].indptr->ctx));
    data_data.push_back(data[i].Ptr<IdType>());
    indptr_data.push_back(csrs[i].indptr.Ptr<IdType>());
    indices_data.push_back(csrs[i].indices.Ptr<IdType>());
    num_edges += csrs[i].indices->shape[0];
    sorted &= csrs[i].sorted;
  }

  std::vector<IdType> res_indptr(num_rows + 1);
  std::vector<IdType> res_indices(num_edges);
  std::vector<IdType> res_data(num_edges);

  // Build the whole result indptr up front. This pass is inherently serial
  // (prefix sum), but once it is done every row owns a disjoint output range,
  // so the per-row loops below carry no dependency between iterations.
  res_indptr[0] = 0;
  for (int64_t i = 1; i <= num_rows; ++i) {
    IdType row_nnz = 0;
    for (size_t j = 0; j < num_mats; ++j) {
      row_nnz += indptr_data[j][i] - indptr_data[j][i-1];
    }
    res_indptr[i] = res_indptr[i-1] + row_nnz;
  }

  // Raw output pointers: unlike &vec[off], pointer arithmetic stays valid
  // when off == size (empty trailing rows) or when the result is empty.
  IdType* const out_indices = res_indices.data();
  IdType* const out_data = res_data.data();

  if (sorted) {  // all csrs are sorted: k-way merge per row
#pragma omp parallel for
    for (int64_t i = 1; i <= num_rows; ++i) {
      // cursor[j] is the next unread position in row i-1 of csrs[j]
      std::vector<int64_t> cursor(num_mats);
      for (size_t j = 0; j < num_mats; ++j) {
        cursor[j] = indptr_data[j][i-1];
      }
      for (IdType off = res_indptr[i-1]; off < res_indptr[i]; ++off) {
        // pick the smallest pending column id; ties go to the lowest j
        int64_t min_idx = -1;
        IdType min_col = 0;
        for (size_t j = 0; j < num_mats; ++j) {
          if (cursor[j] >= indptr_data[j][i]) {
            continue;  // row of csrs[j] is exhausted
          }
          const IdType col = indices_data[j][cursor[j]];
          if (min_idx < 0 || col < min_col) {
            min_col = col;
            min_idx = static_cast<int64_t>(j);
          }
        }  // for
        // res_indptr guarantees that at least one cursor is still pending
        out_indices[off] = min_col;
        out_data[off] = data_data[min_idx][cursor[min_idx]];
        ++cursor[min_idx];
      }  // for off
    }  // omp parallel for
  } else {  // some csrs are not sorted: concatenate the rows in input order
#pragma omp parallel for
    for (int64_t i = 1; i <= num_rows; ++i) {
      IdType off = res_indptr[i-1];
      for (size_t j = 0; j < num_mats; ++j) {
        const IdType begin = indptr_data[j][i-1];
        const IdType end = indptr_data[j][i];
        std::copy(indices_data[j] + begin, indices_data[j] + end,
                  out_indices + off);
        std::copy(data_data[j] + begin, data_data[j] + end,
                  out_data + off);
        off += end - begin;
      }
    }  // omp parallel for
  }

  return CSRMatrix(
    csrs[0].num_rows,
    csrs[0].num_cols,
    IdArray::FromVector(res_indptr),
    IdArray::FromVector(res_indices),
    IdArray::FromVector(res_data),
    sorted);
}

template CSRMatrix UnionCsr<kDLCPU, int64_t>(const std::vector<CSRMatrix>&);
template CSRMatrix UnionCsr<kDLCPU, int32_t>(const std::vector<CSRMatrix>&);
} // namespace impl
} // namespace aten
} // namespace dgl
...@@ -764,6 +764,442 @@ TEST(DisjointUnionTest, TestDisjointUnionPartitionCsr) { ...@@ -764,6 +764,442 @@ TEST(DisjointUnionTest, TestDisjointUnionPartitionCsr) {
#endif #endif
} }
/*!
 * \brief Checks aten::UnionCsr for sorted and unsorted inputs, with and
 *        without explicit data arrays, for two and for three operands.
 *
 * \param ctx Context on which every test array is created.
 */
template <typename IdType>
void _TestMatrixUnionCsr(DLContext ctx) {
  /*
   * A = [[0, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [0, 1, 0, 0],
   *      [1, 1, 1, 1],
   *      [0, 1, 1, 0],
   *      [1, 0, 0, 1]]
   *
   * B = [[0, 0, 0, 0],
   *      [1, 0, 0, 1],
   *      [0, 0, 1, 0],
   *      [1, 0, 0, 1],
   *      [1, 0, 0, 1],
   *      [1, 0, 0, 1]]
   *
   * C = UnionCsr({A, B})
   *
   * C = [[0, 0, 0, 0],
   *      [1, 0, 0, 1],
   *      [0, 1, 1, 0],
   *      [2, 1, 1, 2],
   *      [1, 1, 1, 1],
   *      [2, 0, 0, 2]]
   *
   * D = [[1, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [1, 0, 0, 1]]
   *
   * C = UnionCsr({A, B, D})
   *
   * C = [[1, 0, 0, 0],
   *      [1, 0, 0, 1],
   *      [0, 1, 1, 0],
   *      [2, 1, 1, 2],
   *      [1, 1, 1, 1],
   *      [3, 0, 0, 3]]
   */
  // builds an IdArray of IdType on the requested context
  const auto ids = [&ctx](const std::vector<IdType>& values) {
    return aten::VecToIdArray(values, sizeof(IdType)*8, ctx);
  };

  // ---- A (sorted, no data) U B (sorted, no data) ----
  IdArray a_indptr = ids({0, 0, 0, 1, 5, 7, 9});
  IdArray a_indices = ids({1, 0, 1, 2, 3, 1, 2, 0, 3});
  IdArray b_indptr = ids({0, 0, 2, 3, 5, 7, 9});
  IdArray b_indices = ids({0, 3, 2, 0, 3, 0, 3, 0, 3});
  IdArray c_indptr = ids({0, 0, 2, 4, 10, 14, 18});
  IdArray c_indices =
    ids({0, 3, 1, 2, 0, 0, 1, 2, 3, 3, 0, 1, 2, 3, 0, 0, 3, 3});
  IdArray c_data =
    ids({9, 10, 0, 11, 1, 12, 2, 3, 4,
         13, 14, 5, 6, 15, 7, 16, 8, 17});
  const aten::CSRMatrix csr_a(
    6, 4, a_indptr, a_indices, aten::NullArray(), true);
  const aten::CSRMatrix csr_b(
    6, 4, b_indptr, b_indices, aten::NullArray(), true);

  const aten::CSRMatrix csr_aUb = aten::UnionCsr({csr_a, csr_b});
  ASSERT_EQ(csr_aUb.num_rows, 6);
  ASSERT_EQ(csr_aUb.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb.indptr, c_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb.indices, c_indices));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb.data, c_data));
  ASSERT_TRUE(csr_aUb.sorted);

  // ---- A with explicit (reversed) edge ids U B ----
  IdArray a_data = ids({8, 7, 6, 5, 4, 3, 2, 1, 0});
  c_data =
    ids({9, 10, 8, 11, 7, 12, 6, 5, 4,
         13, 14, 3, 2, 15, 1, 16, 0, 17});
  const aten::CSRMatrix csr_ad(
    6, 4, a_indptr, a_indices, a_data, true);

  const aten::CSRMatrix csr_adUb = aten::UnionCsr({csr_ad, csr_b});
  ASSERT_EQ(csr_adUb.num_rows, 6);
  ASSERT_EQ(csr_adUb.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(csr_adUb.indptr, c_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_adUb.indices, c_indices));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_adUb.data, c_data));
  ASSERT_TRUE(csr_adUb.sorted);

  // ---- A (sorted) U B2 (unsorted): rows are concatenated, not merged ----
  IdArray b_indices2 = ids({0, 3, 2, 0, 3, 3, 0, 0, 3});
  c_indices =
    ids({0, 3, 1, 2, 0, 1, 2, 3, 0, 3, 1, 2, 3, 0, 0, 3, 0, 3});
  c_data =
    ids({9, 10, 0, 11, 1, 2, 3, 4, 12,
         13, 5, 6, 14, 15, 7, 8, 16, 17});
  const aten::CSRMatrix csr_b2(
    6, 4, b_indptr, b_indices2, aten::NullArray(), false);

  const aten::CSRMatrix csr_aUb2 = aten::UnionCsr({csr_a, csr_b2});
  ASSERT_EQ(csr_aUb2.num_rows, 6);
  ASSERT_EQ(csr_aUb2.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb2.indptr, c_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb2.indices, c_indices));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb2.data, c_data));
  ASSERT_FALSE(csr_aUb2.sorted);

  // ---- A2 (unsorted) U B (sorted); expected data is unchanged ----
  IdArray a_indices2 = ids({1, 3, 2, 1, 0, 1, 2, 0, 3});
  c_indices =
    ids({0, 3, 1, 2, 3, 2, 1, 0, 0, 3, 1, 2, 0, 3, 0, 3, 0, 3});
  const aten::CSRMatrix csr_a2(
    6, 4, a_indptr, a_indices2, aten::NullArray(), false);

  const aten::CSRMatrix csr_aUb3 = aten::UnionCsr({csr_a2, csr_b});
  ASSERT_EQ(csr_aUb3.num_rows, 6);
  ASSERT_EQ(csr_aUb3.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb3.indptr, c_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb3.indices, c_indices));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb3.data, c_data));
  ASSERT_FALSE(csr_aUb3.sorted);

  // ---- A2 (unsorted) U B2 (unsorted) ----
  c_indices =
    ids({0, 3, 1, 2, 3, 2, 1, 0, 0, 3, 1, 2, 3, 0, 0, 3, 0, 3});

  const aten::CSRMatrix csr_aUb4 = aten::UnionCsr({csr_a2, csr_b2});
  ASSERT_EQ(csr_aUb4.num_rows, 6);
  ASSERT_EQ(csr_aUb4.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb4.indptr, c_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb4.indices, c_indices));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUb4.data, c_data));
  ASSERT_FALSE(csr_aUb4.sorted);

  // ---- three operands, all sorted: A (with data) U B U D ----
  IdArray d_indptr = ids({0, 1, 1, 1, 1, 1, 3});
  IdArray d_indices = ids({0, 0, 3});
  c_indptr = ids({0, 1, 3, 5, 11, 15, 21});
  c_indices =
    ids({0, 0, 3, 1, 2, 0, 0, 1, 2, 3,
         3, 0, 1, 2, 3, 0, 0, 0, 3, 3, 3});
  c_data =
    ids({18, 9, 10, 8, 11, 7, 12, 6, 5, 4,
         13, 14, 3, 2, 15, 1, 16, 19, 0, 17, 20});
  const aten::CSRMatrix csr_d(
    6, 4, d_indptr, d_indices, aten::NullArray(), true);

  const aten::CSRMatrix csr_aUbUd = aten::UnionCsr({csr_ad, csr_b, csr_d});
  ASSERT_EQ(csr_aUbUd.num_rows, 6);
  ASSERT_EQ(csr_aUbUd.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd.indptr, c_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd.indices, c_indices));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd.data, c_data));
  ASSERT_TRUE(csr_aUbUd.sorted);

  // ---- three operands, two of them unsorted: A2 U B2 U D ----
  c_indices =
    ids({0, 0, 3, 1, 2, 3, 2, 1, 0, 0,
         3, 1, 2, 3, 0, 0, 3, 0, 3, 0, 3});
  c_data =
    ids({18, 9, 10, 0, 11, 1, 2, 3, 4, 12,
         13, 5, 6, 14, 15, 7, 8, 16, 17, 19, 20});

  const aten::CSRMatrix csr_aUbUd2 = aten::UnionCsr({csr_a2, csr_b2, csr_d});
  ASSERT_EQ(csr_aUbUd2.num_rows, 6);
  ASSERT_EQ(csr_aUbUd2.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd2.indptr, c_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd2.indices, c_indices));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_aUbUd2.data, c_data));
  ASSERT_FALSE(csr_aUbUd2.sorted);
}

TEST(MatrixUnionTest, TestMatrixUnionCsr) {
  _TestMatrixUnionCsr<int32_t>(CPU);
  _TestMatrixUnionCsr<int64_t>(CPU);
}
/*!
 * \brief Checks aten::UnionCoo with and without explicit data arrays, for
 *        two and for three operands.
 *
 * \param ctx Context on which every test array is created.
 */
template <typename IdType>
void _TestMatrixUnionCoo(DLContext ctx) {
  /*
   * A = [[0, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [0, 1, 0, 0],
   *      [1, 1, 1, 1],
   *      [0, 1, 1, 0],
   *      [1, 0, 0, 1]]
   *
   * B = [[0, 0, 0, 0],
   *      [1, 0, 0, 1],
   *      [0, 0, 1, 0],
   *      [1, 0, 0, 1],
   *      [1, 0, 0, 1],
   *      [1, 0, 0, 1]]
   *
   * C = UnionCoo({A, B})
   *
   * C = [[0, 0, 0, 0],
   *      [1, 0, 0, 1],
   *      [0, 1, 1, 0],
   *      [2, 1, 1, 2],
   *      [1, 1, 1, 1],
   *      [2, 0, 0, 2]]
   *
   * D = [[1, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [0, 0, 0, 0],
   *      [1, 0, 0, 1]]
   *
   * C = UnionCoo({A, B, D})
   *
   * C = [[1, 0, 0, 0],
   *      [1, 0, 0, 1],
   *      [0, 1, 1, 0],
   *      [2, 1, 1, 2],
   *      [1, 1, 1, 1],
   *      [3, 0, 0, 3]]
   */
  // builds an IdArray of IdType on the requested context
  const auto ids = [&ctx](const std::vector<IdType>& values) {
    return aten::VecToIdArray(values, sizeof(IdType)*8, ctx);
  };

  // ---- A (no data) U B (no data): result carries no data either ----
  IdArray a_row = ids({2, 3, 3, 3, 3, 4, 4, 5, 5});
  IdArray a_col = ids({1, 0, 1, 2, 3, 1, 2, 0, 3});
  IdArray b_row = ids({1, 1, 2, 3, 3, 4, 4, 5, 5});
  IdArray b_col = ids({0, 3, 2, 0, 3, 0, 3, 0, 3});
  IdArray c_row =
    ids({2, 3, 3, 3, 3, 4, 4, 5, 5,
         1, 1, 2, 3, 3, 4, 4, 5, 5});
  IdArray c_col =
    ids({1, 0, 1, 2, 3, 1, 2, 0, 3,
         0, 3, 2, 0, 3, 0, 3, 0, 3});
  const aten::COOMatrix coo_a(
    6, 4, a_row, a_col, aten::NullArray(), true, true);
  const aten::COOMatrix coo_b(
    6, 4, b_row, b_col, aten::NullArray(), true, true);

  const std::vector<aten::COOMatrix> coos_ab({coo_a, coo_b});
  const aten::COOMatrix coo_ab = aten::UnionCoo(coos_ab);
  ASSERT_EQ(coo_ab.num_rows, 6);
  ASSERT_EQ(coo_ab.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab.row, c_row));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab.col, c_col));
  ASSERT_FALSE(COOHasData(coo_ab));
  ASSERT_FALSE(coo_ab.row_sorted);
  ASSERT_FALSE(coo_ab.col_sorted);

  // ---- A with explicit data U B without data ----
  IdArray a_data = ids({2, 1, 0, 3, 4, 5, 6, 7, 8});
  IdArray c_data =
    ids({2, 1, 0, 3, 4, 5, 6, 7, 8,
         9, 10, 11, 12, 13, 14, 15, 16, 17});
  const aten::COOMatrix coo_a2(
    6, 4, a_row, a_col, a_data, true, true);

  const std::vector<aten::COOMatrix> coos_ab2({coo_a2, coo_b});
  const aten::COOMatrix coo_ab2 = aten::UnionCoo(coos_ab2);
  ASSERT_EQ(coo_ab2.num_rows, 6);
  ASSERT_EQ(coo_ab2.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab2.row, c_row));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab2.col, c_col));
  ASSERT_TRUE(COOHasData(coo_ab2));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab2.data, c_data));
  ASSERT_FALSE(coo_ab2.row_sorted);
  ASSERT_FALSE(coo_ab2.col_sorted);

  // ---- both operands with explicit data ----
  IdArray b_data = ids({0, 1, 2, 3, 4, 5, 6, 8, 7});
  c_data =
    ids({2, 1, 0, 3, 4, 5, 6, 7, 8,
         9, 10, 11, 12, 13, 14, 15, 17, 16});
  const aten::COOMatrix coo_b2(
    6, 4, b_row, b_col, b_data, true, true);

  const std::vector<aten::COOMatrix> coos_ab3({coo_a2, coo_b2});
  const aten::COOMatrix coo_ab3 = aten::UnionCoo(coos_ab3);
  ASSERT_EQ(coo_ab3.num_rows, 6);
  ASSERT_EQ(coo_ab3.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab3.row, c_row));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab3.col, c_col));
  ASSERT_TRUE(COOHasData(coo_ab3));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab3.data, c_data));
  ASSERT_FALSE(coo_ab3.row_sorted);
  ASSERT_FALSE(coo_ab3.col_sorted);

  // ---- same operands again: the union must be repeatable ----
  c_data =
    ids({2, 1, 0, 3, 4, 5, 6, 7, 8,
         9, 10, 11, 12, 13, 14, 15, 17, 16});

  const std::vector<aten::COOMatrix> coos_ab4({coo_a2, coo_b2});
  const aten::COOMatrix coo_ab4 = aten::UnionCoo(coos_ab4);
  ASSERT_EQ(coo_ab4.num_rows, 6);
  ASSERT_EQ(coo_ab4.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab4.row, c_row));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab4.col, c_col));
  ASSERT_TRUE(COOHasData(coo_ab4));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_ab4.data, c_data));
  ASSERT_FALSE(coo_ab4.row_sorted);
  ASSERT_FALSE(coo_ab4.col_sorted);

  // ---- three operands, none with data ----
  IdArray d_row = ids({0, 5, 5});
  IdArray d_col = ids({0, 0, 3});
  c_row =
    ids({2, 3, 3, 3, 3, 4, 4, 5, 5,
         1, 1, 2, 3, 3, 4, 4, 5, 5,
         0, 5, 5});
  c_col =
    ids({1, 0, 1, 2, 3, 1, 2, 0, 3,
         0, 3, 2, 0, 3, 0, 3, 0, 3,
         0, 0, 3});
  const aten::COOMatrix coo_d(
    6, 4, d_row, d_col, aten::NullArray(), true, true);

  const aten::COOMatrix coo_aUbUd = aten::UnionCoo({coo_a, coo_b, coo_d});
  ASSERT_EQ(coo_aUbUd.num_rows, 6);
  ASSERT_EQ(coo_aUbUd.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(coo_aUbUd.row, c_row));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_aUbUd.col, c_col));
  ASSERT_FALSE(COOHasData(coo_aUbUd));
  ASSERT_FALSE(coo_aUbUd.row_sorted);
  ASSERT_FALSE(coo_aUbUd.col_sorted);

  // ---- three operands, the first two with explicit data ----
  c_data =
    ids({2, 1, 0, 3, 4, 5, 6, 7, 8,
         9, 10, 11, 12, 13, 14, 15, 17, 16,
         18, 19, 20});

  const aten::COOMatrix coo_aUbUd2 = aten::UnionCoo({coo_a2, coo_b2, coo_d});
  ASSERT_EQ(coo_aUbUd2.num_rows, 6);
  ASSERT_EQ(coo_aUbUd2.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(coo_aUbUd2.row, c_row));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_aUbUd2.col, c_col));
  ASSERT_TRUE(COOHasData(coo_aUbUd2));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_aUbUd2.data, c_data));
  ASSERT_FALSE(coo_aUbUd2.row_sorted);
  ASSERT_FALSE(coo_aUbUd2.col_sorted);
}

TEST(MatrixUnionTest, TestMatrixUnionCoo) {
  _TestMatrixUnionCoo<int32_t>(CPU);
  _TestMatrixUnionCoo<int64_t>(CPU);
}
template <typename IDX> template <typename IDX>
void _TestCumSum(DLContext ctx) { void _TestCumSum(DLContext ctx) {
IdArray a = aten::VecToIdArray(std::vector<IDX>({8, 6, 7, 5, 3, 0, 9}), IdArray a = aten::VecToIdArray(std::vector<IDX>({8, 6, 7, 5, 3, 0, 9}),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment