#include #include #include #include #include #include "./common.h" using namespace dgl; template void _TestWithReplacement(RandomEngine *re) { Idx n_categories = 100; Idx n_rolls = 1000000; std::vector prob; DType accum = 0.; for (Idx i = 0; i < n_categories; ++i) { prob.push_back(re->Uniform()); accum += prob.back(); } for (Idx i = 0; i < n_categories; ++i) prob[i] /= accum; auto _check_given_sampler = [n_categories, n_rolls, &prob]( BaseSampler *s) { std::vector counter(n_categories, 0); for (Idx i = 0; i < n_rolls; ++i) { Idx dice = s->draw(); counter[dice]++; } for (Idx i = 0; i < n_categories; ++i) ASSERT_NEAR(static_cast(counter[i]) / n_rolls, prob[i], 1e-2); }; AliasSampler as(re, prob); CDFSampler cs(re, prob); TreeSampler ts(re, prob); _check_given_sampler(&as); _check_given_sampler(&cs); _check_given_sampler(&ts); } TEST(SampleUtilsTest, TestWithReplacement) { RandomEngine* re = RandomEngine::ThreadLocal(); re->SetSeed(42); _TestWithReplacement(re); re->SetSeed(42); _TestWithReplacement(re); re->SetSeed(42); _TestWithReplacement(re); re->SetSeed(42); _TestWithReplacement(re); }; template void _TestWithoutReplacementOrder(RandomEngine *re) { std::vector prob = {1e6, 1e-6, 1e-2, 1e2}; std::vector ground_truth = {0, 3, 2, 1}; auto _check_given_sampler = [&ground_truth]( BaseSampler *s) { for (size_t i = 0; i < ground_truth.size(); ++i) { Idx dice = s->draw(); ASSERT_EQ(dice, ground_truth[i]); } }; AliasSampler as(re, prob); CDFSampler cs(re, prob); TreeSampler ts(re, prob); _check_given_sampler(&as); _check_given_sampler(&cs); _check_given_sampler(&ts); } TEST(SampleUtilsTest, TestWithoutReplacementOrder) { RandomEngine* re = RandomEngine::ThreadLocal(); re->SetSeed(42); _TestWithoutReplacementOrder(re); re->SetSeed(42); _TestWithoutReplacementOrder(re); re->SetSeed(42); _TestWithoutReplacementOrder(re); re->SetSeed(42); _TestWithoutReplacementOrder(re); }; template void _TestWithoutReplacementUnique(RandomEngine *re) { Idx N = 1000000; std::vector likelihood; for (Idx i = 0; i < N; ++i) likelihood.push_back(re->Uniform()); auto _check_given_sampler = [N]( BaseSampler *s) { std::vector cnt(N, 0); for (Idx i = 0; i < N; ++i) { Idx dice = s->draw(); cnt[dice]++; } for (Idx i = 0; i < N; ++i) ASSERT_EQ(cnt[i], 1); }; AliasSampler as(re, likelihood); CDFSampler cs(re, likelihood); TreeSampler ts(re, likelihood); _check_given_sampler(&as); _check_given_sampler(&cs); _check_given_sampler(&ts); } TEST(SampleUtilsTest, TestWithoutReplacementUnique) { RandomEngine* re = RandomEngine::ThreadLocal(); re->SetSeed(42); _TestWithoutReplacementUnique(re); re->SetSeed(42); _TestWithoutReplacementUnique(re); re->SetSeed(42); _TestWithoutReplacementUnique(re); re->SetSeed(42); _TestWithoutReplacementUnique(re); };