#include #include #include #include "../../src/array/cpu/array_utils.h" // PairHash #include "./common.h" using namespace dgl; using namespace dgl::runtime; namespace { // Unit tests: // CSRMM(A, B) == A_mm_B // CSRSum({A, C}) == A_plus_C // CSRMask(A, C) = A_mask_C template std::unordered_map, DType, aten::PairHash> COOToMap( aten::COOMatrix coo, NDArray weights) { std::unordered_map, DType, aten::PairHash> map; for (int64_t i = 0; i < coo.row->shape[0]; ++i) { IdType irow = aten::IndexSelect(coo.row, i); IdType icol = aten::IndexSelect(coo.col, i); IdType ieid = aten::COOHasData(coo) ? aten::IndexSelect(coo.data, i) : i; DType idata = aten::IndexSelect(weights, ieid); map.insert({{irow, icol}, idata}); } return map; } template bool CSRIsClose( aten::CSRMatrix A, aten::CSRMatrix B, NDArray A_weights, NDArray B_weights, DType rtol, DType atol) { auto Amap = COOToMap(CSRToCOO(A, false), A_weights); auto Bmap = COOToMap(CSRToCOO(B, false), B_weights); if (Amap.size() != Bmap.size()) return false; for (auto itA : Amap) { auto itB = Bmap.find(itA.first); if (itB == Bmap.end()) return false; if (fabs(itA.second - itB->second) >= rtol * fabs(itA.second) + atol) return false; } return true; } template std::pair CSR_A(DLContext ctx = CTX) { // matrix([[0. , 0. , 1. , 0.7, 0. ], // [0. , 0. , 0.5, 0.+, 0. ], // [0.4, 0.7, 0. , 0.2, 0. ], // [0. , 0. , 0. , 0. , 0.2]]) // (0.+ indicates that the entry exists but the value is 0.) auto csr = aten::CSRMatrix( 4, 5, NDArray::FromVector(std::vector({0, 2, 4, 7, 8}), ctx), NDArray::FromVector(std::vector({2, 3, 2, 3, 0, 1, 3, 4}), ctx), NDArray::FromVector(std::vector({1, 0, 2, 3, 4, 5, 6, 7}), ctx)); auto weights = NDArray::FromVector( std::vector({0.7, 1.0, 0.5, 0.0, 0.4, 0.7, 0.2, 0.2}), ctx); return {csr, weights}; } template std::pair CSR_B(DLContext ctx = CTX) { // matrix([[0. , 0.9, 0. , 0.6, 0. , 0.3], // [0. , 0. , 0. , 0. , 0. , 0.4], // [0.+, 0. , 0. , 0. , 0. , 0.9], // [0.8, 0.2, 0.3, 0.2, 0. , 0. ], // [0.2, 0.4, 0. , 0. , 0. , 0. ]]) // (0.+ indicates that the entry exists but the value is 0.) auto csr = aten::CSRMatrix( 5, 6, NDArray::FromVector(std::vector({0, 3, 4, 6, 10, 12}), ctx), NDArray::FromVector(std::vector({1, 3, 5, 5, 0, 5, 0, 1, 2, 3, 0, 1}), ctx)); auto weights = NDArray::FromVector( std::vector({0.9, 0.6, 0.3, 0.4, 0.0, 0.9, 0.8, 0.2, 0.3, 0.2, 0.2, 0.4}), ctx); return {csr, weights}; } template std::pair CSR_C(DLContext ctx = CTX) { // matrix([[0. , 0. , 0. , 0.2, 0. ], // [0. , 0. , 0. , 0.5, 0.4], // [0. , 0.2, 0. , 0.9, 0.2], // [0. , 1. , 0. , 0.7, 0. ]]) auto csr = aten::CSRMatrix( 4, 5, NDArray::FromVector(std::vector({0, 1, 3, 6, 8}), ctx), NDArray::FromVector(std::vector({3, 3, 4, 1, 3, 4, 1, 3}), ctx)); auto weights = NDArray::FromVector( std::vector({0.2, 0.5, 0.4, 0.2, 0.9, 0.2, 1. , 0.7}), ctx); return {csr, weights}; } template std::pair CSR_A_mm_B(DLContext ctx = CTX) { // matrix([[0.56, 0.14, 0.21, 0.14, 0. , 0.9 ], // [0.+ , 0.+ , 0.+ , 0.+ , 0. , 0.45], // [0.16, 0.4 , 0.06, 0.28, 0. , 0.4 ], // [0.04, 0.08, 0. , 0. , 0. , 0. ]]) // (0.+ indicates that the entry exists but the value is 0.) auto csr = aten::CSRMatrix( 4, 6, NDArray::FromVector(std::vector({0, 5, 10, 15, 17}), ctx), NDArray::FromVector(std::vector( {0, 1, 2, 3, 5, 0, 1, 2, 3, 5, 0, 1, 2, 3, 5, 0, 1}), ctx)); auto weights = NDArray::FromVector( std::vector({ 0.56, 0.14, 0.21, 0.14, 0.9 , 0. , 0. , 0. , 0. , 0.45, 0.16, 0.4 , 0.06, 0.28, 0.4 , 0.04, 0.08}), ctx); return {csr, weights}; } template std::pair CSR_A_plus_C(DLContext ctx = CTX) { auto csr = aten::CSRMatrix( 4, 5, NDArray::FromVector(std::vector({0, 2, 5, 9, 12}), ctx), NDArray::FromVector(std::vector({2, 3, 2, 3, 4, 0, 1, 3, 4, 1, 3, 4}), ctx)); auto weights = NDArray::FromVector( std::vector({1. , 0.9, 0.5, 0.5, 0.4, 0.4, 0.9, 1.1, 0.2, 1. , 0.7, 0.2}), ctx); return {csr, weights}; } template NDArray CSR_A_mask_C(DLContext ctx = CTX) { return NDArray::FromVector(std::vector({0.7, 0.0, 0.0, 0.7, 0.2, 0.0, 0.0, 0.0}), ctx); } template void _TestCsrmm(DLContext ctx = CTX) { auto A = CSR_A(ctx); auto B = CSR_B(ctx); auto A_mm_B = aten::CSRMM(A.first, A.second, B.first, B.second); auto A_mm_B2 = CSR_A_mm_B(ctx); bool result = CSRIsClose(A_mm_B.first, A_mm_B2.first, A_mm_B.second, A_mm_B2.second, 1e-4, 1e-4); ASSERT_TRUE(result); } template void _TestCsrsum(DLContext ctx = CTX) { auto A = CSR_A(ctx); auto C = CSR_C(ctx); auto A_plus_C = aten::CSRSum({A.first, C.first}, {A.second, C.second}); auto A_plus_C2 = CSR_A_plus_C(ctx); bool result = CSRIsClose( A_plus_C.first, A_plus_C2.first, A_plus_C.second, A_plus_C2.second, 1e-4, 1e-4); ASSERT_TRUE(result); } template void _TestCsrmask(DLContext ctx = CTX) { auto A = CSR_A(ctx); auto C = CSR_C(ctx); auto C_coo = CSRToCOO(C.first, false); auto A_mask_C = aten::CSRGetData(A.first, C_coo.row, C_coo.col, A.second, 0); auto A_mask_C2 = CSR_A_mask_C(ctx); ASSERT_TRUE(ArrayEQ(A_mask_C, A_mask_C2)); } TEST(CsrmmTest, TestCsrmm) { _TestCsrmm(CPU); _TestCsrmm(CPU); _TestCsrmm(CPU); _TestCsrmm(CPU); #ifdef DGL_USE_CUDA _TestCsrmm(GPU); _TestCsrmm(GPU); _TestCsrmm(GPU); _TestCsrmm(GPU); #endif } TEST(CsrmmTest, TestCsrsum) { _TestCsrsum(CPU); _TestCsrsum(CPU); _TestCsrsum(CPU); _TestCsrsum(CPU); #ifdef DGL_USE_CUDA _TestCsrsum(GPU); _TestCsrsum(GPU); _TestCsrsum(GPU); _TestCsrsum(GPU); #endif } TEST(CsrmmTest, TestCsrmask) { _TestCsrmask(CPU); _TestCsrmask(CPU); _TestCsrmask(CPU); _TestCsrmask(CPU); #ifdef DGL_USE_CUDA _TestCsrmask(GPU); _TestCsrmask(GPU); _TestCsrmask(GPU); _TestCsrmask(GPU); #endif } }; // namespace