#include #include #include #include #include "./common.h" using namespace dgl; using namespace dgl::runtime; using namespace dgl::aten; template using ETuple = std::tuple; template std::set> AllEdgeSet(bool has_data) { if (has_data) { std::set> eset; eset.insert(ETuple{0, 0, 2}); eset.insert(ETuple{0, 1, 3}); eset.insert(ETuple{1, 1, 0}); eset.insert(ETuple{3, 2, 1}); eset.insert(ETuple{3, 3, 4}); return eset; } else { std::set> eset; eset.insert(ETuple{0, 0, 0}); eset.insert(ETuple{0, 1, 1}); eset.insert(ETuple{1, 1, 2}); eset.insert(ETuple{3, 2, 3}); eset.insert(ETuple{3, 3, 4}); return eset; } } template std::set> AllEdgePerEtypeSet(bool has_data) { if (has_data) { std::set> eset; eset.insert(ETuple{0, 0, 2}); eset.insert(ETuple{0, 1, 3}); eset.insert(ETuple{0, 2, 5}); eset.insert(ETuple{0, 3, 6}); eset.insert(ETuple{3, 2, 1}); eset.insert(ETuple{3, 3, 4}); return eset; } else { std::set> eset; eset.insert(ETuple{0, 0, 0}); eset.insert(ETuple{0, 1, 1}); eset.insert(ETuple{0, 2, 2}); eset.insert(ETuple{0, 3, 3}); eset.insert(ETuple{3, 2, 5}); eset.insert(ETuple{3, 3, 6}); return eset; } } template std::set> ToEdgeSet(COOMatrix mat) { std::set> eset; Idx* row = static_cast(mat.row->data); Idx* col = static_cast(mat.col->data); Idx* data = static_cast(mat.data->data); for (int64_t i = 0; i < mat.row->shape[0]; ++i) { //std::cout << row[i] << " " << col[i] << " " << data[i] << std::endl; eset.emplace(row[i], col[i], data[i]); } return eset; } template void CheckSampledResult(COOMatrix mat, IdArray rows, bool has_data) { ASSERT_EQ(mat.num_rows, 4); ASSERT_EQ(mat.num_cols, 4); Idx* row = static_cast(mat.row->data); Idx* col = static_cast(mat.col->data); Idx* data = static_cast(mat.data->data); const auto& gt = AllEdgeSet(has_data); for (int64_t i = 0; i < mat.row->shape[0]; ++i) { ASSERT_TRUE(gt.count(std::make_tuple(row[i], col[i], data[i]))); ASSERT_TRUE(IsInArray(rows, row[i])); } } template void CheckSampledPerEtypeReplaceResult(COOMatrix mat, IdArray rows, bool has_data) { ASSERT_EQ(mat.num_rows, 4); ASSERT_EQ(mat.num_cols, 4); Idx* row = static_cast(mat.row->data); Idx* col = static_cast(mat.col->data); Idx* data = static_cast(mat.data->data); const auto& gt = AllEdgePerEtypeSet(has_data); for (int64_t i = 0; i < mat.row->shape[0]; ++i) { ASSERT_TRUE(gt.count(std::make_tuple(row[i], col[i], data[i]))); ASSERT_TRUE(IsInArray(rows, row[i])); } } template void CheckSampledPerEtypeResult(COOMatrix mat, IdArray rows, bool has_data) { ASSERT_EQ(mat.num_rows, 4); ASSERT_EQ(mat.num_cols, 4); Idx* row = static_cast(mat.row->data); Idx* col = static_cast(mat.col->data); Idx* data = static_cast(mat.data->data); const auto& gt = AllEdgePerEtypeSet(has_data); int cnt_0 = 0; int cnt_3 = 0; for (int64_t i = 0; i < mat.row->shape[0]; ++i) { ASSERT_TRUE(gt.count(std::make_tuple(row[i], col[i], data[i]))); ASSERT_TRUE(IsInArray(rows, row[i])); if (row[i] == 0) cnt_0 += 1; if (row[i] == 3) cnt_3 += 1; } ASSERT_EQ(cnt_0, 3); ASSERT_EQ(cnt_3, 2); } template CSRMatrix CSR(bool has_data) { IdArray indptr = NDArray::FromVector(std::vector({0, 2, 3, 3, 5})); IdArray indices = NDArray::FromVector(std::vector({0, 1, 1, 2, 3})); IdArray data = NDArray::FromVector(std::vector({2, 3, 0, 1, 4})); if (has_data) return CSRMatrix(4, 4, indptr, indices, data); else return CSRMatrix(4, 4, indptr, indices); } template COOMatrix COO(bool has_data) { IdArray row = NDArray::FromVector(std::vector({0, 0, 1, 3, 3})); IdArray col = NDArray::FromVector(std::vector({0, 1, 1, 2, 3})); IdArray data = NDArray::FromVector(std::vector({2, 3, 0, 1, 4})); if (has_data) return COOMatrix(4, 4, row, col, data); else return COOMatrix(4, 4, row, col); } template CSRMatrix CSREtypes(bool has_data) { IdArray indptr = NDArray::FromVector(std::vector({0, 4, 5, 5, 7})); IdArray indices = NDArray::FromVector(std::vector({0, 1, 2, 3, 1, 2, 3})); IdArray data = NDArray::FromVector(std::vector({2, 3, 5, 6, 0, 1, 4})); if (has_data) return CSRMatrix(4, 4, indptr, indices, data); else return CSRMatrix(4, 4, indptr, indices); } template COOMatrix COOEtypes(bool has_data) { IdArray row = NDArray::FromVector(std::vector({0, 0, 0, 0, 1, 3, 3})); IdArray col = NDArray::FromVector(std::vector({0, 1, 2, 3, 1, 2, 3})); IdArray data = NDArray::FromVector(std::vector({2, 3, 5, 6, 0, 1, 4})); if (has_data) return COOMatrix(4, 4, row, col, data); else return COOMatrix(4, 4, row, col); } template void _TestCSRSampling(bool has_data) { auto mat = CSR(has_data); FloatArray prob = NDArray::FromVector( std::vector({.5, .5, .5, .5, .5})); IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 4); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } prob = NDArray::FromVector( std::vector({.0, .5, .5, .0, .5})); for (int k = 0; k < 100; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 1, 3))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3))); } } } TEST(RowwiseTest, TestCSRSampling) { _TestCSRSampling(true); _TestCSRSampling(true); _TestCSRSampling(true); _TestCSRSampling(true); _TestCSRSampling(false); _TestCSRSampling(false); _TestCSRSampling(false); _TestCSRSampling(false); } template void _TestCSRSamplingUniform(bool has_data) { auto mat = CSR(has_data); FloatArray prob = aten::NullArray(); IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } } TEST(RowwiseTest, TestCSRSamplingUniform) { _TestCSRSamplingUniform(true); _TestCSRSamplingUniform(true); _TestCSRSamplingUniform(true); _TestCSRSamplingUniform(true); _TestCSRSamplingUniform(false); _TestCSRSamplingUniform(false); _TestCSRSamplingUniform(false); _TestCSRSamplingUniform(false); } template void _TestCSRPerEtypeSampling(bool has_data) { auto mat = CSREtypes(has_data); FloatArray prob = NDArray::FromVector( std::vector({.5, .5, .5, .5, .5, .5, .5})); IdArray rows = NDArray::FromVector(std::vector({0, 3})); IdArray etypes = has_data ? NDArray::FromVector(std::vector({3, 1, 3, 3, 2, 3, 0})) : NDArray::FromVector(std::vector({3, 3, 3, 0, 3, 1, 2})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 2)); counts += eset.count(std::make_tuple(0, 1, 3)); counts += eset.count(std::make_tuple(0, 2, 5)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 0)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 1)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 4)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 6)); ASSERT_EQ(counts, 1); } } prob = has_data ? NDArray::FromVector( std::vector({.0, .5, .0, .5, .5, .0, .5})) : NDArray::FromVector( std::vector({.0, .5, .0, .5, .0, .5, .5})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 5))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2))); } } } template void _TestCSRPerEtypeSamplingSorted(bool has_data, bool etype_sorted) { auto mat = CSREtypes(has_data); FloatArray prob = NDArray::FromVector( std::vector({.5, .5, .5, .5, .5, .5, .5})); IdArray rows = NDArray::FromVector(std::vector({0, 3})); IdArray etypes = has_data ? NDArray::FromVector(std::vector({0, 1, 0, 0, 2, 0, 3})) : NDArray::FromVector(std::vector({0, 0, 0, 3, 0, 1, 2})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 2)); counts += eset.count(std::make_tuple(0, 1, 3)); counts += eset.count(std::make_tuple(0, 2, 5)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 0)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 1)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 4)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 6)); ASSERT_EQ(counts, 1); } } prob = has_data ? NDArray::FromVector( std::vector({.0, .5, .0, .5, .5, .0, .5})) : NDArray::FromVector( std::vector({.0, .5, .0, .5, .0, .5, .5})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 5))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2))); } } } TEST(RowwiseTest, TestCSRPerEtypeSampling) { _TestCSRPerEtypeSampling(true); _TestCSRPerEtypeSampling(true); _TestCSRPerEtypeSampling(true); _TestCSRPerEtypeSampling(true); _TestCSRPerEtypeSampling(false); _TestCSRPerEtypeSampling(false); _TestCSRPerEtypeSampling(false); _TestCSRPerEtypeSampling(false); _TestCSRPerEtypeSamplingSorted(true, true); _TestCSRPerEtypeSamplingSorted(true, true); _TestCSRPerEtypeSamplingSorted(true, true); _TestCSRPerEtypeSamplingSorted(true, true); _TestCSRPerEtypeSamplingSorted(false, true); _TestCSRPerEtypeSamplingSorted(false, true); _TestCSRPerEtypeSamplingSorted(false, true); _TestCSRPerEtypeSamplingSorted(false, true); _TestCSRPerEtypeSamplingSorted(true, false); _TestCSRPerEtypeSamplingSorted(true, false); _TestCSRPerEtypeSamplingSorted(true, false); _TestCSRPerEtypeSamplingSorted(true, false); _TestCSRPerEtypeSamplingSorted(false, false); _TestCSRPerEtypeSamplingSorted(false, false); _TestCSRPerEtypeSamplingSorted(false, false); _TestCSRPerEtypeSamplingSorted(false, false); } template void _TestCSRPerEtypeSamplingUniform(bool has_data) { auto mat = CSREtypes(has_data); FloatArray prob = aten::NullArray(); IdArray rows = NDArray::FromVector(std::vector({0, 3})); IdArray etypes = has_data ? NDArray::FromVector(std::vector({3, 1, 3, 3, 2, 3, 0})) : NDArray::FromVector(std::vector({3, 3, 3, 0, 3, 1, 2})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 2)); counts += eset.count(std::make_tuple(0, 1, 3)); counts += eset.count(std::make_tuple(0, 2, 5)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 0)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 1)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 4)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 6)); ASSERT_EQ(counts, 1); } } } template void _TestCSRPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) { auto mat = CSREtypes(has_data); FloatArray prob = aten::NullArray(); IdArray rows = NDArray::FromVector(std::vector({0, 3})); IdArray etypes = has_data ? NDArray::FromVector(std::vector({0, 1, 0, 0, 2, 0, 3})) : NDArray::FromVector(std::vector({0, 0, 0, 3, 0, 1, 2})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 2)); counts += eset.count(std::make_tuple(0, 1, 3)); counts += eset.count(std::make_tuple(0, 2, 5)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 0)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 1)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 4)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 6)); ASSERT_EQ(counts, 1); } } } TEST(RowwiseTest, TestCSRPerEtypeSamplingUniform) { _TestCSRPerEtypeSamplingUniform(true); _TestCSRPerEtypeSamplingUniform(true); _TestCSRPerEtypeSamplingUniform(true); _TestCSRPerEtypeSamplingUniform(true); _TestCSRPerEtypeSamplingUniform(false); _TestCSRPerEtypeSamplingUniform(false); _TestCSRPerEtypeSamplingUniform(false); _TestCSRPerEtypeSamplingUniform(false); _TestCSRPerEtypeSamplingUniformSorted(true, true); _TestCSRPerEtypeSamplingUniformSorted(true, true); _TestCSRPerEtypeSamplingUniformSorted(true, true); _TestCSRPerEtypeSamplingUniformSorted(true, true); _TestCSRPerEtypeSamplingUniformSorted(false, true); _TestCSRPerEtypeSamplingUniformSorted(false, true); _TestCSRPerEtypeSamplingUniformSorted(false, true); _TestCSRPerEtypeSamplingUniformSorted(false, true); _TestCSRPerEtypeSamplingUniformSorted(true, false); _TestCSRPerEtypeSamplingUniformSorted(true, false); _TestCSRPerEtypeSamplingUniformSorted(true, false); _TestCSRPerEtypeSamplingUniformSorted(true, false); _TestCSRPerEtypeSamplingUniformSorted(false, false); _TestCSRPerEtypeSamplingUniformSorted(false, false); _TestCSRPerEtypeSamplingUniformSorted(false, false); _TestCSRPerEtypeSamplingUniformSorted(false, false); } template void _TestCOOSampling(bool has_data) { auto mat = COO(has_data); FloatArray prob = NDArray::FromVector( std::vector({.5, .5, .5, .5, .5})); IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 4); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } prob = NDArray::FromVector( std::vector({.0, .5, .5, .0, .5})); for (int k = 0; k < 100; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 1, 3))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3))); } } } TEST(RowwiseTest, TestCOOSampling) { _TestCOOSampling(true); _TestCOOSampling(true); _TestCOOSampling(true); _TestCOOSampling(true); _TestCOOSampling(false); _TestCOOSampling(false); _TestCOOSampling(false); _TestCOOSampling(false); } template void _TestCOOSamplingUniform(bool has_data) { auto mat = COO(has_data); FloatArray prob = aten::NullArray(); IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } } TEST(RowwiseTest, TestCOOSamplingUniform) { _TestCOOSamplingUniform(true); _TestCOOSamplingUniform(true); _TestCOOSamplingUniform(true); _TestCOOSamplingUniform(true); _TestCOOSamplingUniform(false); _TestCOOSamplingUniform(false); _TestCOOSamplingUniform(false); _TestCOOSamplingUniform(false); } template void _TestCOOerEtypeSampling(bool has_data) { auto mat = COOEtypes(has_data); FloatArray prob = NDArray::FromVector( std::vector({.5, .5, .5, .5, .5, .5, .5})); IdArray rows = NDArray::FromVector(std::vector({0, 3})); IdArray etypes = has_data ? NDArray::FromVector(std::vector({3, 1, 3, 3, 2, 3, 0})) : NDArray::FromVector(std::vector({3, 3, 3, 0, 3, 1, 2})); for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 2)); counts += eset.count(std::make_tuple(0, 1, 3)); counts += eset.count(std::make_tuple(0, 2, 5)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 0)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 1)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 4)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 6)); ASSERT_EQ(counts, 1); } } prob = has_data ? NDArray::FromVector( std::vector({.0, .5, .0, .5, .5, .0, .5})) : NDArray::FromVector( std::vector({.0, .5, .0, .5, .0, .5, .5})); for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 5))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2))); } } } template void _TestCOOerEtypeSamplingSorted(bool has_data, bool etype_sorted) { auto mat = COOEtypes(has_data); FloatArray prob = NDArray::FromVector( std::vector({.5, .5, .5, .5, .5, .5, .5})); IdArray rows = NDArray::FromVector(std::vector({0, 3})); IdArray etypes = has_data ? NDArray::FromVector(std::vector({0, 1, 0, 0, 2, 0, 3})) : NDArray::FromVector(std::vector({0, 0, 0, 3, 0, 1, 2})); for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 2)); counts += eset.count(std::make_tuple(0, 1, 3)); counts += eset.count(std::make_tuple(0, 2, 5)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 0)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 1)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 4)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 6)); ASSERT_EQ(counts, 1); } } prob = has_data ? NDArray::FromVector( std::vector({.0, .5, .0, .5, .5, .0, .5})) : NDArray::FromVector( std::vector({.0, .5, .0, .5, .0, .5, .5})); for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 5))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2))); } } } TEST(RowwiseTest, TestCOOerEtypeSampling) { _TestCOOerEtypeSampling(true); _TestCOOerEtypeSampling(true); _TestCOOerEtypeSampling(true); _TestCOOerEtypeSampling(true); _TestCOOerEtypeSampling(false); _TestCOOerEtypeSampling(false); _TestCOOerEtypeSampling(false); _TestCOOerEtypeSampling(false); _TestCOOerEtypeSamplingSorted(true, true); _TestCOOerEtypeSamplingSorted(true, true); _TestCOOerEtypeSamplingSorted(true, true); _TestCOOerEtypeSamplingSorted(true, true); _TestCOOerEtypeSamplingSorted(false, true); _TestCOOerEtypeSamplingSorted(false, true); _TestCOOerEtypeSamplingSorted(false, true); _TestCOOerEtypeSamplingSorted(false, true); _TestCOOerEtypeSamplingSorted(true, false); _TestCOOerEtypeSamplingSorted(true, false); _TestCOOerEtypeSamplingSorted(true, false); _TestCOOerEtypeSamplingSorted(true, false); _TestCOOerEtypeSamplingSorted(false, false); _TestCOOerEtypeSamplingSorted(false, false); _TestCOOerEtypeSamplingSorted(false, false); _TestCOOerEtypeSamplingSorted(false, false); } template void _TestCOOPerEtypeSamplingUniform(bool has_data) { auto mat = COOEtypes(has_data); FloatArray prob = aten::NullArray(); IdArray rows = NDArray::FromVector(std::vector({0, 3})); IdArray etypes = has_data ? NDArray::FromVector(std::vector({3, 1, 3, 3, 2, 3, 0})) : NDArray::FromVector(std::vector({3, 3, 3, 0, 3, 1, 2})); for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 2)); counts += eset.count(std::make_tuple(0, 1, 3)); counts += eset.count(std::make_tuple(0, 2, 5)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 0)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 1)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 4)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 6)); ASSERT_EQ(counts, 1); } } } template void _TestCOOPerEtypeSamplingUniformSorted(bool has_data, bool etype_sorted) { auto mat = COOEtypes(has_data); FloatArray prob = aten::NullArray(); IdArray rows = NDArray::FromVector(std::vector({0, 3})); IdArray etypes = has_data ? NDArray::FromVector(std::vector({0, 1, 0, 0, 2, 0, 3})) : NDArray::FromVector(std::vector({0, 0, 0, 3, 0, 1, 2})); for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, true, etype_sorted); CheckSampledPerEtypeReplaceResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, etypes, {2, 2, 2, 2}, prob, false, etype_sorted); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 2)); counts += eset.count(std::make_tuple(0, 1, 3)); counts += eset.count(std::make_tuple(0, 2, 5)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 0)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 1)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 4)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 6)); ASSERT_EQ(counts, 1); } } } TEST(RowwiseTest, TestCOOPerEtypeSamplingUniform) { _TestCOOPerEtypeSamplingUniform(true); _TestCOOPerEtypeSamplingUniform(true); _TestCOOPerEtypeSamplingUniform(true); _TestCOOPerEtypeSamplingUniform(true); _TestCOOPerEtypeSamplingUniform(false); _TestCOOPerEtypeSamplingUniform(false); _TestCOOPerEtypeSamplingUniform(false); _TestCOOPerEtypeSamplingUniform(false); _TestCOOPerEtypeSamplingUniformSorted(true, true); _TestCOOPerEtypeSamplingUniformSorted(true, true); _TestCOOPerEtypeSamplingUniformSorted(true, true); _TestCOOPerEtypeSamplingUniformSorted(true, true); _TestCOOPerEtypeSamplingUniformSorted(false, true); _TestCOOPerEtypeSamplingUniformSorted(false, true); _TestCOOPerEtypeSamplingUniformSorted(false, true); _TestCOOPerEtypeSamplingUniformSorted(false, true); _TestCOOPerEtypeSamplingUniformSorted(true, false); _TestCOOPerEtypeSamplingUniformSorted(true, false); _TestCOOPerEtypeSamplingUniformSorted(true, false); _TestCOOPerEtypeSamplingUniformSorted(true, false); _TestCOOPerEtypeSamplingUniformSorted(false, false); _TestCOOPerEtypeSamplingUniformSorted(false, false); _TestCOOPerEtypeSamplingUniformSorted(false, false); _TestCOOPerEtypeSamplingUniformSorted(false, false); } template void _TestCSRTopk(bool has_data) { auto mat = CSR(has_data); FloatArray weight = NDArray::FromVector( std::vector({.1f, .0f, -.1f, .2f, .5f})); // -.1, .2, .1, .0, .5 IdArray rows = NDArray::FromVector(std::vector({0, 3})); { auto rst = CSRRowWiseTopk(mat, rows, 1, weight, true); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 2); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); } } { auto rst = CSRRowWiseTopk(mat, rows, 1, weight, false); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 2); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } } TEST(RowwiseTest, TestCSRTopk) { _TestCSRTopk(true); _TestCSRTopk(true); _TestCSRTopk(true); _TestCSRTopk(true); _TestCSRTopk(false); _TestCSRTopk(false); _TestCSRTopk(false); _TestCSRTopk(false); } template void _TestCOOTopk(bool has_data) { auto mat = COO(has_data); FloatArray weight = NDArray::FromVector( std::vector({.1f, .0f, -.1f, .2f, .5f})); // -.1, .2, .1, .0, .5 IdArray rows = NDArray::FromVector(std::vector({0, 3})); { auto rst = COORowWiseTopk(mat, rows, 1, weight, true); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 2); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); } } { auto rst = COORowWiseTopk(mat, rows, 1, weight, false); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 2); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } } TEST(RowwiseTest, TestCOOTopk) { _TestCOOTopk(true); _TestCOOTopk(true); _TestCOOTopk(true); _TestCOOTopk(true); _TestCOOTopk(false); _TestCOOTopk(false); _TestCOOTopk(false); _TestCOOTopk(false); } template void _TestCSRSamplingBiased(bool has_data) { auto mat = CSR(has_data); // 0 - 0,1 // 1 - 1 // 3 - 2,3 NDArray tag_offset = NDArray::FromVector( std::vector({0, 1, 2, 0, 0, 1, 0, 0, 0, 0, 1, 2})); tag_offset = tag_offset.CreateView({4, 3}, tag_offset->dtype); IdArray rows = NDArray::FromVector(std::vector({0, 1, 3})); FloatArray bias = NDArray::FromVector( std::vector({0, 0.5}) ); for (int k = 0 ; k < 10 ; ++k) { auto rst = CSRRowWiseSamplingBiased(mat, rows, 1, tag_offset, bias, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 0))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 2))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } for (int k = 0 ; k < 10 ; ++k) { auto rst = CSRRowWiseSamplingBiased(mat, rows, 3, tag_offset, bias, true); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 0))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 1))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 2))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3))); } } } TEST(RowwiseTest, TestCSRSamplingBiased) { _TestCSRSamplingBiased(true); _TestCSRSamplingBiased(false); _TestCSRSamplingBiased(true); _TestCSRSamplingBiased(false); _TestCSRSamplingBiased(true); _TestCSRSamplingBiased(false); _TestCSRSamplingBiased(true); _TestCSRSamplingBiased(false); }