#include #include #include #include #include "./common.h" using namespace dgl; using namespace dgl::runtime; using namespace dgl::aten; template using ETuple = std::tuple; template std::set> AllEdgeSet(bool has_data) { if (has_data) { std::set> eset; eset.insert(ETuple{0, 0, 2}); eset.insert(ETuple{0, 1, 3}); eset.insert(ETuple{1, 1, 0}); eset.insert(ETuple{3, 2, 1}); eset.insert(ETuple{3, 3, 4}); return eset; } else { std::set> eset; eset.insert(ETuple{0, 0, 0}); eset.insert(ETuple{0, 1, 1}); eset.insert(ETuple{1, 1, 2}); eset.insert(ETuple{3, 2, 3}); eset.insert(ETuple{3, 3, 4}); return eset; } } template std::set> AllEdgePerEtypeSet(bool has_data) { if (has_data) { std::set> eset; eset.insert(ETuple{0, 0, 0}); eset.insert(ETuple{0, 1, 1}); eset.insert(ETuple{0, 2, 4}); eset.insert(ETuple{0, 3, 6}); eset.insert(ETuple{3, 2, 5}); eset.insert(ETuple{3, 3, 3}); return eset; } else { std::set> eset; eset.insert(ETuple{0, 0, 0}); eset.insert(ETuple{0, 1, 1}); eset.insert(ETuple{0, 2, 2}); eset.insert(ETuple{0, 3, 3}); eset.insert(ETuple{3, 3, 5}); eset.insert(ETuple{3, 2, 6}); return eset; } } template std::set> ToEdgeSet(COOMatrix mat) { std::set> eset; Idx* row = static_cast(mat.row->data); Idx* col = static_cast(mat.col->data); Idx* data = static_cast(mat.data->data); for (int64_t i = 0; i < mat.row->shape[0]; ++i) { //std::cout << row[i] << " " << col[i] << " " << data[i] << std::endl; eset.emplace(row[i], col[i], data[i]); } return eset; } template void CheckSampledResult(COOMatrix mat, IdArray rows, bool has_data) { ASSERT_EQ(mat.num_rows, 4); ASSERT_EQ(mat.num_cols, 4); Idx* row = static_cast(mat.row->data); Idx* col = static_cast(mat.col->data); Idx* data = static_cast(mat.data->data); const auto& gt = AllEdgeSet(has_data); for (int64_t i = 0; i < mat.row->shape[0]; ++i) { ASSERT_TRUE(gt.count(std::make_tuple(row[i], col[i], data[i]))); ASSERT_TRUE(IsInArray(rows, row[i])); } } template void CheckSampledPerEtypeResult(COOMatrix mat, IdArray rows, bool has_data) { ASSERT_EQ(mat.num_rows, 4); ASSERT_EQ(mat.num_cols, 4); Idx* row = static_cast(mat.row->data); Idx* col = static_cast(mat.col->data); Idx* data = static_cast(mat.data->data); const auto& gt = AllEdgePerEtypeSet(has_data); for (int64_t i = 0; i < mat.row->shape[0]; ++i) { int64_t count = gt.count(std::make_tuple(row[i], col[i], data[i])); ASSERT_TRUE(count); ASSERT_TRUE(IsInArray(rows, row[i])); } } template CSRMatrix CSR(bool has_data) { IdArray indptr = NDArray::FromVector(std::vector({0, 2, 3, 3, 5})); IdArray indices = NDArray::FromVector(std::vector({0, 1, 1, 2, 3})); IdArray data = NDArray::FromVector(std::vector({2, 3, 0, 1, 4})); if (has_data) return CSRMatrix(4, 4, indptr, indices, data); else return CSRMatrix(4, 4, indptr, indices); } template COOMatrix COO(bool has_data) { IdArray row = NDArray::FromVector(std::vector({0, 0, 1, 3, 3})); IdArray col = NDArray::FromVector(std::vector({0, 1, 1, 2, 3})); IdArray data = NDArray::FromVector(std::vector({2, 3, 0, 1, 4})); if (has_data) return COOMatrix(4, 4, row, col, data); else return COOMatrix(4, 4, row, col); } template std::pair> CSREtypes(bool has_data) { IdArray indptr = NDArray::FromVector(std::vector({0, 4, 5, 5, 7})); IdArray indices = NDArray::FromVector(std::vector({0, 1, 2, 3, 1, 3, 2})); IdArray data = NDArray::FromVector(std::vector({0, 1, 4, 6, 2, 3, 5})); auto eid2etype_offsets = std::vector({0, 4, 5, 6, 7}); if (has_data) return {CSRMatrix(4, 4, indptr, indices, data), eid2etype_offsets}; else return {CSRMatrix(4, 4, indptr, indices), eid2etype_offsets}; } template std::pair> COOEtypes(bool has_data) { IdArray row = NDArray::FromVector(std::vector({0, 0, 0, 0, 1, 3, 3})); IdArray col = NDArray::FromVector(std::vector({0, 1, 2, 3, 1, 3, 2})); IdArray data = NDArray::FromVector(std::vector({0, 1, 4, 6, 2, 3, 5})); auto eid2etype_offsets = std::vector({0, 4, 5, 6, 7}); if (has_data) return {COOMatrix(4, 4, row, col, data), eid2etype_offsets}; else return {COOMatrix(4, 4, row, col), eid2etype_offsets}; } template void _TestCSRSampling(bool has_data) { auto mat = CSR(has_data); FloatArray prob = NDArray::FromVector( std::vector({.5, .5, .5, .5, .5})); IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 4); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } prob = NDArray::FromVector( std::vector({.0, .5, .5, .0, .5})); for (int k = 0; k < 100; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 1, 3))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3))); } } } TEST(RowwiseTest, TestCSRSampling) { _TestCSRSampling(true); _TestCSRSampling(true); _TestCSRSampling(true); _TestCSRSampling(true); _TestCSRSampling(false); _TestCSRSampling(false); _TestCSRSampling(false); _TestCSRSampling(false); } template void _TestCSRSamplingUniform(bool has_data) { auto mat = CSR(has_data); FloatArray prob = aten::NullArray(); IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWiseSampling(mat, rows, 2, prob, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } } TEST(RowwiseTest, TestCSRSamplingUniform) { _TestCSRSamplingUniform(true); _TestCSRSamplingUniform(true); _TestCSRSamplingUniform(true); _TestCSRSamplingUniform(true); _TestCSRSamplingUniform(false); _TestCSRSamplingUniform(false); _TestCSRSamplingUniform(false); _TestCSRSamplingUniform(false); } template void _TestCSRPerEtypeSampling(bool has_data) { auto pair = CSREtypes(has_data); auto mat = pair.first; auto eid2etype_offset = pair.second; std::vector prob = { NDArray::FromVector(std::vector({.5, .5, .5, .5})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})) }; IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 2, 4)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 2)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 3)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 3, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 2, 6)); ASSERT_EQ(counts, 1); } } prob = { NDArray::FromVector(std::vector({.0, .5, .0, .0})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})) }; for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2))); ASSERT_FALSE(eset.count(std::make_tuple(0, 3, 3))); } } } template void _TestCSRPerEtypeSamplingSorted() { auto pair = CSREtypes(true); auto mat = pair.first; auto eid2etype_offset = pair.second; std::vector prob = { NDArray::FromVector(std::vector({.5, .5, .5, .5})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})) }; IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true); CheckSampledPerEtypeResult(rst, rows, true); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true); CheckSampledPerEtypeResult(rst, rows, true); auto eset = ToEdgeSet(rst); int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 2, 4)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 2)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 3)); ASSERT_EQ(counts, 1); } prob = { NDArray::FromVector(std::vector({.0, .5, .0, .0})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})) }; for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true); CheckSampledPerEtypeResult(rst, rows, true); auto eset = ToEdgeSet(rst); ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); } } TEST(RowwiseTest, TestCSRPerEtypeSampling) { _TestCSRPerEtypeSampling(true); _TestCSRPerEtypeSampling(true); _TestCSRPerEtypeSampling(true); _TestCSRPerEtypeSampling(true); _TestCSRPerEtypeSampling(false); _TestCSRPerEtypeSampling(false); _TestCSRPerEtypeSampling(false); _TestCSRPerEtypeSampling(false); _TestCSRPerEtypeSamplingSorted(); _TestCSRPerEtypeSamplingSorted(); _TestCSRPerEtypeSamplingSorted(); _TestCSRPerEtypeSamplingSorted(); } template void _TestCSRPerEtypeSamplingUniform(bool has_data) { auto pair = CSREtypes(has_data); auto mat = pair.first; auto eid2etype_offset = pair.second; std::vector prob = { aten::NullArray(), aten::NullArray(), aten::NullArray(), aten::NullArray() }; IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 2, 4)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 2)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 3)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 3, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 2, 6)); ASSERT_EQ(counts, 1); } } } template void _TestCSRPerEtypeSamplingUniformSorted() { auto pair = CSREtypes(true); auto mat = pair.first; auto eid2etype_offset = pair.second; std::vector prob = { aten::NullArray(), aten::NullArray(), aten::NullArray(), aten::NullArray() }; IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true); CheckSampledPerEtypeResult(rst, rows, true); } for (int k = 0; k < 10; ++k) { auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true); CheckSampledPerEtypeResult(rst, rows, true); auto eset = ToEdgeSet(rst); int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 2, 4)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 2)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 3)); ASSERT_EQ(counts, 1); } } TEST(RowwiseTest, TestCSRPerEtypeSamplingUniform) { _TestCSRPerEtypeSamplingUniform(true); _TestCSRPerEtypeSamplingUniform(true); _TestCSRPerEtypeSamplingUniform(true); _TestCSRPerEtypeSamplingUniform(true); _TestCSRPerEtypeSamplingUniform(false); _TestCSRPerEtypeSamplingUniform(false); _TestCSRPerEtypeSamplingUniform(false); _TestCSRPerEtypeSamplingUniform(false); _TestCSRPerEtypeSamplingUniformSorted(); _TestCSRPerEtypeSamplingUniformSorted(); _TestCSRPerEtypeSamplingUniformSorted(); _TestCSRPerEtypeSamplingUniformSorted(); } template void _TestCOOSampling(bool has_data) { auto mat = COO(has_data); FloatArray prob = NDArray::FromVector( std::vector({.5, .5, .5, .5, .5})); IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 4); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } prob = NDArray::FromVector( std::vector({.0, .5, .5, .0, .5})); for (int k = 0; k < 100; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 1, 3))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3))); } } } TEST(RowwiseTest, TestCOOSampling) { _TestCOOSampling(true); _TestCOOSampling(true); _TestCOOSampling(true); _TestCOOSampling(true); _TestCOOSampling(false); _TestCOOSampling(false); _TestCOOSampling(false); _TestCOOSampling(false); } template void _TestCOOSamplingUniform(bool has_data) { auto mat = COO(has_data); FloatArray prob = aten::NullArray(); IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, true); CheckSampledResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWiseSampling(mat, rows, 2, prob, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } } TEST(RowwiseTest, TestCOOSamplingUniform) { _TestCOOSamplingUniform(true); _TestCOOSamplingUniform(true); _TestCOOSamplingUniform(true); _TestCOOSamplingUniform(true); _TestCOOSamplingUniform(false); _TestCOOSamplingUniform(false); _TestCOOSamplingUniform(false); _TestCOOSamplingUniform(false); } // COOPerEtypeSampling with rowwise_etype_sorted == true is not meaningful as // it's never used in practice. template void _TestCOOPerEtypeSampling(bool has_data) { auto pair = COOEtypes(has_data); auto mat = pair.first; auto eid2etype_offset = pair.second; std::vector prob = { NDArray::FromVector(std::vector({.5, .5, .5, .5})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})) }; IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 2, 4)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 2)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 3)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 3, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 2, 6)); ASSERT_EQ(counts, 1); } } prob = { NDArray::FromVector(std::vector({.0, .5, .0, .0})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})), NDArray::FromVector(std::vector({.5})) }; for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); } else { ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(0, 2, 2))); ASSERT_FALSE(eset.count(std::make_tuple(0, 3, 3))); } } } TEST(RowwiseTest, TestCOOPerEtypeSampling) { _TestCOOPerEtypeSampling(true); _TestCOOPerEtypeSampling(true); _TestCOOPerEtypeSampling(true); _TestCOOPerEtypeSampling(true); _TestCOOPerEtypeSampling(false); _TestCOOPerEtypeSampling(false); _TestCOOPerEtypeSampling(false); _TestCOOPerEtypeSampling(false); } template void _TestCOOPerEtypeSamplingUniform(bool has_data) { auto pair = COOEtypes(has_data); auto mat = pair.first; auto eid2etype_offset = pair.second; std::vector prob = { aten::NullArray(), aten::NullArray(), aten::NullArray(), aten::NullArray() }; IdArray rows = NDArray::FromVector(std::vector({0, 3})); for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); CheckSampledPerEtypeResult(rst, rows, has_data); } for (int k = 0; k < 10; ++k) { auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false); CheckSampledPerEtypeResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(0, 2, 4)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(0, 3, 6)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(1, 1, 2)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 2, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 3, 3)); ASSERT_EQ(counts, 1); } else { int counts = 0; counts += eset.count(std::make_tuple(0, 0, 0)); counts += eset.count(std::make_tuple(0, 1, 1)); counts += eset.count(std::make_tuple(0, 2, 2)); counts += eset.count(std::make_tuple(0, 3, 3)); ASSERT_EQ(counts, 2); counts = 0; counts += eset.count(std::make_tuple(1, 1, 4)); ASSERT_EQ(counts, 0); counts = 0; counts += eset.count(std::make_tuple(3, 3, 5)); ASSERT_EQ(counts, 1); counts = 0; counts += eset.count(std::make_tuple(3, 2, 6)); ASSERT_EQ(counts, 1); } } } TEST(RowwiseTest, TestCOOPerEtypeSamplingUniform) { _TestCOOPerEtypeSamplingUniform(true); _TestCOOPerEtypeSamplingUniform(true); _TestCOOPerEtypeSamplingUniform(true); _TestCOOPerEtypeSamplingUniform(true); _TestCOOPerEtypeSamplingUniform(false); _TestCOOPerEtypeSamplingUniform(false); _TestCOOPerEtypeSamplingUniform(false); _TestCOOPerEtypeSamplingUniform(false); } template void _TestCSRTopk(bool has_data) { auto mat = CSR(has_data); FloatArray weight = NDArray::FromVector( std::vector({.1f, .0f, -.1f, .2f, .5f})); // -.1, .2, .1, .0, .5 IdArray rows = NDArray::FromVector(std::vector({0, 3})); { auto rst = CSRRowWiseTopk(mat, rows, 1, weight, true); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 2); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); } } { auto rst = CSRRowWiseTopk(mat, rows, 1, weight, false); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 2); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } } TEST(RowwiseTest, TestCSRTopk) { _TestCSRTopk(true); _TestCSRTopk(true); _TestCSRTopk(true); _TestCSRTopk(true); _TestCSRTopk(false); _TestCSRTopk(false); _TestCSRTopk(false); _TestCSRTopk(false); } template void _TestCOOTopk(bool has_data) { auto mat = COO(has_data); FloatArray weight = NDArray::FromVector( std::vector({.1f, .0f, -.1f, .2f, .5f})); // -.1, .2, .1, .0, .5 IdArray rows = NDArray::FromVector(std::vector({0, 3})); { auto rst = COORowWiseTopk(mat, rows, 1, weight, true); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 2); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); } } { auto rst = COORowWiseTopk(mat, rows, 1, weight, false); auto eset = ToEdgeSet(rst); ASSERT_EQ(eset.size(), 2); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } } TEST(RowwiseTest, TestCOOTopk) { _TestCOOTopk(true); _TestCOOTopk(true); _TestCOOTopk(true); _TestCOOTopk(true); _TestCOOTopk(false); _TestCOOTopk(false); _TestCOOTopk(false); _TestCOOTopk(false); } template void _TestCSRSamplingBiased(bool has_data) { auto mat = CSR(has_data); // 0 - 0,1 // 1 - 1 // 3 - 2,3 NDArray tag_offset = NDArray::FromVector( std::vector({0, 1, 2, 0, 0, 1, 0, 0, 0, 0, 1, 2})); tag_offset = tag_offset.CreateView({4, 3}, tag_offset->dtype); IdArray rows = NDArray::FromVector(std::vector({0, 1, 3})); FloatArray bias = NDArray::FromVector( std::vector({0, 0.5}) ); for (int k = 0 ; k < 10 ; ++k) { auto rst = CSRRowWiseSamplingBiased(mat, rows, 1, tag_offset, bias, false); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 0))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 2))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); } } for (int k = 0 ; k < 10 ; ++k) { auto rst = CSRRowWiseSamplingBiased(mat, rows, 3, tag_offset, bias, true); CheckSampledResult(rst, rows, has_data); auto eset = ToEdgeSet(rst); if (has_data) { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 0))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 1))); } else { ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(1, 1, 2))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3))); } } } TEST(RowwiseTest, TestCSRSamplingBiased) { _TestCSRSamplingBiased(true); _TestCSRSamplingBiased(false); _TestCSRSamplingBiased(true); _TestCSRSamplingBiased(false); _TestCSRSamplingBiased(true); _TestCSRSamplingBiased(false); _TestCSRSamplingBiased(true); _TestCSRSamplingBiased(false); }