#include #include #include #include #include "./common.h" #include "../../src/random/cpu/sample_utils.h" using namespace dgl; using namespace dgl::aten; // TODO: adapt this to Random::Choice template void _TestWithReplacement(RandomEngine *re) { Idx n_categories = 100; Idx n_rolls = 1000000; std::vector _prob; DType accum = 0.; for (Idx i = 0; i < n_categories; ++i) { _prob.push_back(re->Uniform()); accum += _prob.back(); } for (Idx i = 0; i < n_categories; ++i) _prob[i] /= accum; FloatArray prob = NDArray::FromVector(_prob); auto _check_given_sampler = [n_categories, n_rolls, &_prob]( utils::BaseSampler *s) { std::vector counter(n_categories, 0); for (Idx i = 0; i < n_rolls; ++i) { Idx dice = s->Draw(); counter[dice]++; } for (Idx i = 0; i < n_categories; ++i) ASSERT_NEAR(static_cast(counter[i]) / n_rolls, _prob[i], 1e-2); }; auto _check_random_choice = [n_categories, n_rolls, &_prob, prob]() { std::vector counter(n_categories, 0); for (Idx i = 0; i < n_rolls; ++i) { Idx dice = RandomEngine::ThreadLocal()->Choice(prob); counter[dice]++; } for (Idx i = 0; i < n_categories; ++i) ASSERT_NEAR(static_cast(counter[i]) / n_rolls, _prob[i], 1e-2); }; utils::AliasSampler as(re, prob); utils::CDFSampler cs(re, prob); utils::TreeSampler ts(re, prob); _check_given_sampler(&as); _check_given_sampler(&cs); _check_given_sampler(&ts); _check_random_choice(); } TEST(SampleUtilsTest, TestWithReplacement) { RandomEngine* re = RandomEngine::ThreadLocal(); re->SetSeed(42); _TestWithReplacement(re); re->SetSeed(42); _TestWithReplacement(re); re->SetSeed(42); _TestWithReplacement(re); re->SetSeed(42); _TestWithReplacement(re); }; template void _TestWithoutReplacementOrder(RandomEngine *re) { // TODO(BarclayII): is there a reliable way to do this test? std::vector _prob = {1e6, 1e-6, 1e-2, 1e2}; FloatArray prob = NDArray::FromVector(_prob); std::vector ground_truth = {0, 3, 2, 1}; auto _check_given_sampler = [&ground_truth]( utils::BaseSampler *s) { for (size_t i = 0; i < ground_truth.size(); ++i) { Idx dice = s->Draw(); ASSERT_EQ(dice, ground_truth[i]); } }; utils::AliasSampler as(re, prob); utils::CDFSampler cs(re, prob); utils::TreeSampler ts(re, prob); _check_given_sampler(&as); _check_given_sampler(&cs); _check_given_sampler(&ts); } TEST(SampleUtilsTest, TestWithoutReplacementOrder) { RandomEngine* re = RandomEngine::ThreadLocal(); re->SetSeed(42); _TestWithoutReplacementOrder(re); re->SetSeed(42); _TestWithoutReplacementOrder(re); re->SetSeed(42); _TestWithoutReplacementOrder(re); re->SetSeed(42); _TestWithoutReplacementOrder(re); }; template void _TestWithoutReplacementUnique(RandomEngine *re) { Idx N = 1000000; std::vector _likelihood; for (Idx i = 0; i < N; ++i) _likelihood.push_back(re->Uniform()); FloatArray likelihood = NDArray::FromVector(_likelihood); auto _check_given_sampler = [N]( utils::BaseSampler *s) { std::vector cnt(N, 0); for (Idx i = 0; i < N; ++i) { Idx dice = s->Draw(); cnt[dice]++; } for (Idx i = 0; i < N; ++i) ASSERT_EQ(cnt[i], 1); }; utils::AliasSampler as(re, likelihood); utils::CDFSampler cs(re, likelihood); utils::TreeSampler ts(re, likelihood); _check_given_sampler(&as); _check_given_sampler(&cs); _check_given_sampler(&ts); } TEST(SampleUtilsTest, TestWithoutReplacementUnique) { RandomEngine* re = RandomEngine::ThreadLocal(); re->SetSeed(42); _TestWithoutReplacementUnique(re); re->SetSeed(42); _TestWithoutReplacementUnique(re); re->SetSeed(42); _TestWithoutReplacementUnique(re); re->SetSeed(42); _TestWithoutReplacementUnique(re); };