Unverified Commit dfb10db8 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Feature] Bag of samplers for efficient categorical sampling w/ and w/o replacement (#1142)

* upd

* upd

* lint

* upd

* upd

* lint

* upd

* refactor

* lint

* upd

* upd

* upd

* lint

* fix

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* complete test

* upd

* upd

* fix

* vir

* upd

* fix

* fuck

* fix numerical
parent 162dc19a
......@@ -14,8 +14,6 @@
namespace dgl {
using namespace dgl::runtime;
namespace {
inline uint32_t GetThreadId() {
......
/*!
* Copyright (c) 2019 by Contributors
* \file dgl/sample_utils.h
* \brief Sampling utilities
*/
#ifndef DGL_SAMPLE_UTILS_H_
#define DGL_SAMPLE_UTILS_H_
#include <algorithm>
#include <utility>
#include <queue>
#include <cstdlib>
#include <cmath>
#include <numeric>
#include <limits>
#include <vector>
#include "random.h"
namespace dgl {
/*!
 * \brief Common interface of the categorical samplers defined in this file.
 *
 * \tparam Idx integer type of the returned category id.
 * \tparam DType floating point type of the likelihoods.
 * \tparam replace whether sampling is done with replacement.
 */
template <
  typename Idx,
  typename DType,
  bool replace>
class BaseSampler {
 public:
  /*!
   * \brief Virtual destructor, so that a sampler owned through a
   *        BaseSampler pointer is destroyed correctly.
   */
  virtual ~BaseSampler() = default;
  /*!
   * \brief Draw one category id.
   *
   * The base implementation only logs a message and returns 0; concrete
   * samplers provide the real behaviour.
   * \return the sampled category id.
   */
  virtual Idx draw() {
    LOG(INFO) << "Not implemented yet.";
    return 0;
  }
};
/*
 * AliasSampler is used to sample elements from a given discrete categorical distribution.
 * Algorithm: Alias Method(https://en.wikipedia.org/wiki/Alias_method)
 * Sampler building complexity: O(n)
 * Sample w/ replacement complexity: O(1)
 * Sample w/o replacement complexity: O(log n)
 *
 * When replace=false, already drawn categories are rejected and the tables are
 * rebuilt over the remaining categories once at least half of the table's
 * weight has been taken.
 */
template <
  typename Idx,
  typename DType,
  bool replace>
class AliasSampler: public BaseSampler<Idx, DType, replace> {
 private:
  RandomEngine *re;
  Idx N;                        // number of slots in the current tables
  DType accum, taken;           // weight held by the tables / weight drawn since last rebuild
  std::vector<Idx> K;           // alias table, indexed by slot
  std::vector<DType> U;         // probability table, indexed by slot
  std::vector<DType> _prob;     // category distribution, stored when replace=false
  std::vector<bool> used;       // indicate availability, activated when replace=false;
  std::vector<Idx> id_mapping;  // slot -> category id, activated when replace=false;

  /* Translate a table slot into the category id it stands for. */
  inline Idx map(Idx x) const {
    if (replace)
      return x;
    else
      return id_mapping[x];
  }

  /* (Re)build the alias and probability tables over all categories not marked as used. */
  void rebuild(const std::vector<DType>& prob) {
    N = 0;
    accum = 0.;
    taken = 0.;
    if (!replace)
      id_mapping.clear();
    for (size_t i = 0; i < prob.size(); ++i) {
      if (used[i])
        continue;
      N++;
      accum += prob[i];
      if (!replace)
        id_mapping.push_back(static_cast<Idx>(i));
    }
    if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'";
    K.resize(N);
    U.resize(N);
    const DType avg = accum / static_cast<DType>(N);
    std::fill(U.begin(), U.end(), avg);  // initialize U
    std::queue<std::pair<Idx, DType> > under, over;
    for (Idx i = 0; i < N; ++i) {
      const DType p = prob[map(i)];
      if (p > avg)
        over.push(std::make_pair(i, p));
      else
        under.push(std::make_pair(i, p));
      K[i] = i;  // initialize K
    }
    // Pair an under-full slot with an over-full one: the under-full slot keeps
    // its own mass and borrows the rest of its bucket from the over-full slot.
    while (!under.empty() && !over.empty()) {
      const std::pair<Idx, DType> u_pair = under.front();
      const std::pair<Idx, DType> o_pair = over.front();
      const Idx i_u = u_pair.first, i_o = o_pair.first;
      const DType p_u = u_pair.second, p_o = o_pair.second;
      K[i_u] = i_o;
      U[i_u] = p_u;
      // Mass left on the over-full slot after filling the bucket of i_u.
      // If it is exactly avg, the slot keeps its initial U = avg and K = itself.
      if (p_o + p_u > 2 * avg)
        over.push(std::make_pair(i_o, p_o + p_u - avg));
      else if (p_o + p_u < 2 * avg)
        under.push(std::make_pair(i_o, p_o + p_u - avg));
      under.pop();
      over.pop();
    }
  }

  /* Pick one slot from the current tables and return the category id it resolves to. */
  inline Idx sample_once() {
    // Computed from the current state, so it is always consistent with the
    // tables even right after a rebuild.
    const DType avg = accum / N;
    const DType dice = re->Uniform<DType>(0, N);
    Idx i = static_cast<Idx>(dice);
    // Guard against an engine that returns the upper bound itself
    // (NOTE(review): the interval convention of RandomEngine::Uniform is not visible here).
    if (i >= N)
      i = N - 1;
    const DType p = (dice - i) * avg;
    // U and K are indexed by slot; only the final result is mapped to a category id.
    if (p <= U[i])
      return map(i);
    else
      return map(K[i]);
  }

 public:
  /*!
   * \brief Reset the sampler to a fresh state over the given distribution.
   * \param prob unnormalized likelihood of every category.
   */
  void reinit_state(const std::vector<DType>& prob) {
    used.resize(prob.size());
    if (!replace)
      _prob = prob;
    std::fill(used.begin(), used.end(), false);
    rebuild(prob);
  }

  /*!
   * \brief Build a sampler.
   * \param re random engine used for every draw; not owned by the sampler.
   * \param prob unnormalized likelihood of every category.
   */
  explicit AliasSampler(RandomEngine* re, const std::vector<DType>& prob): re(re) {
    reinit_state(prob);
  }

  ~AliasSampler() {}

  /*!
   * \brief Draw one category id.
   *
   * With replace=false each id is returned at most once; asking for more
   * samples than there are categories triggers the fatal log in rebuild().
   * \return the sampled category id.
   */
  Idx draw() override {
    if (replace)
      return sample_once();
    // Rebuild once at least half of the table's weight has been drawn.
    if (2 * taken >= accum)
      rebuild(_prob);
    while (true) {
      const Idx rst = sample_once();
      const DType cap = _prob[rst];
      if (!used[rst]) {
        used[rst] = true;
        taken += cap;
        return rst;
      }
    }
  }
};
/*
 * CDFSampler is used to sample elements from a given discrete categorical distribution.
 * Algorithm: create a cumulative distribution function and conduct binary search for sampling.
 * Reference: https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804
 * Sampler building complexity: O(n)
 * Sample w/ and w/o replacement complexity: O(log n)
 *
 * When replace=false, already drawn categories are rejected and the CDF is
 * rebuilt over the remaining categories once at least half of its weight has
 * been taken.
 */
template <
  typename Idx,
  typename DType,
  bool replace>
class CDFSampler: public BaseSampler<Idx, DType, replace> {
 private:
  RandomEngine *re;
  Idx N;                        // number of categories covered by the current CDF
  DType accum, taken;           // weight held by the CDF / weight drawn since last rebuild
  std::vector<DType> _prob;     // categorical distribution, stored when replace=false
  std::vector<DType> cdf;       // cumulative distribution function, cdf[0] = 0, N + 1 entries
  std::vector<bool> used;       // indicate availability, activated when replace=false;
  std::vector<Idx> id_mapping;  // slot -> category id, activated when replace=false;

  /* Translate a CDF slot into the category id it stands for. */
  inline Idx map(Idx x) const {
    if (replace)
      return x;
    else
      return id_mapping[x];
  }

  /* (Re)build the CDF over all categories not marked as used. */
  void rebuild(const std::vector<DType>& prob) {
    N = 0;
    accum = 0.;
    taken = 0.;
    if (!replace)
      id_mapping.clear();
    cdf.clear();
    cdf.push_back(0);
    for (size_t i = 0; i < prob.size(); ++i) {
      if (used[i])
        continue;
      N++;
      accum += prob[i];
      if (!replace)
        id_mapping.push_back(static_cast<Idx>(i));
      cdf.push_back(accum);
    }
    if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'";
  }

  /* Binary-search the CDF with one random number and return the matching category id. */
  inline Idx sample_once() {
    // Keep p strictly positive so that the search never lands on cdf[0].
    const DType eps = std::numeric_limits<DType>::min();
    const DType p = std::max(re->Uniform<DType>(0., accum), eps);
    Idx slot = static_cast<Idx>(
        std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin()) - 1;
    // If p exceeds the last CDF entry (e.g. the covered weight is zero, so
    // p == eps > accum), the search returns end(); fall back to the last slot
    // instead of indexing one past the end.
    if (slot >= N)
      slot = N - 1;
    return map(slot);
  }

 public:
  /*!
   * \brief Reset the sampler to a fresh state over the given distribution.
   * \param prob unnormalized likelihood of every category.
   */
  void reinit_state(const std::vector<DType>& prob) {
    used.resize(prob.size());
    if (!replace)
      _prob = prob;
    std::fill(used.begin(), used.end(), false);
    rebuild(prob);
  }

  /*!
   * \brief Build a sampler.
   * \param re random engine used for every draw; not owned by the sampler.
   * \param prob unnormalized likelihood of every category.
   */
  explicit CDFSampler(RandomEngine *re, const std::vector<DType>& prob): re(re) {
    reinit_state(prob);
  }

  ~CDFSampler() {}

  /*!
   * \brief Draw one category id.
   *
   * With replace=false each id is returned at most once; asking for more
   * samples than there are categories triggers the fatal log in rebuild().
   * \return the sampled category id.
   */
  Idx draw() override {
    if (replace)
      return sample_once();
    // Rebuild once at least half of the CDF's weight has been drawn.
    if (2 * taken >= accum)
      rebuild(_prob);
    while (true) {
      const Idx rst = sample_once();
      const DType cap = _prob[rst];
      if (!used[rst]) {
        used[rst] = true;
        taken += cap;
        return rst;
      }
    }
  }
};
/*
 * TreeSampler is used to sample elements from a given discrete categorical distribution.
 * Algorithm: create a heap that stores accumulated likelihood of its leaf descendents.
 * Reference: https://blog.smola.org/post/1016514759
 * Sampler building complexity: O(n)
 * Sample w/ and w/o replacement complexity: O(log n)
 *
 * The tree is stored as an implicit binary heap: node i has children 2i and
 * 2i + 1, the root is node 1 and the leaves start at index num_leafs.
 * When replace=false, a drawn leaf has its weight set to zero and the sums
 * on its path to the root are refreshed.
 */
template <
  typename Idx,
  typename DType,
  bool replace>
class TreeSampler: public BaseSampler<Idx, DType, replace> {
 private:
  RandomEngine *re;
  std::vector<DType> weight;  // accumulated likelihood of subtrees.
  int64_t N, num_leafs;       // size of the weight array / number of leaf nodes

 public:
  /*!
   * \brief Reset the sampler to a fresh state over the given distribution.
   *
   * The tree is sized from prob, so the distribution may have a different
   * number of categories than the one used at construction.
   * \param prob unnormalized likelihood of every category.
   */
  void reinit_state(const std::vector<DType>& prob) {
    const int64_t n_items = static_cast<int64_t>(prob.size());
    // Number of leaves: the smallest power of two not below n_items (at least 1).
    num_leafs = 1;
    while (num_leafs < n_items)
      num_leafs *= 2;
    N = num_leafs * 2;
    weight.assign(static_cast<size_t>(N), static_cast<DType>(0));
    std::copy(prob.begin(), prob.end(), weight.begin() + num_leafs);
    // Fill the internal nodes bottom-up.
    for (int64_t i = num_leafs - 1; i >= 1; --i)
      weight[i] = weight[i * 2] + weight[i * 2 + 1];
  }

  /*!
   * \brief Build a sampler.
   * \param re random engine used for every draw; not owned by the sampler.
   * \param prob unnormalized likelihood of every category.
   */
  explicit TreeSampler(RandomEngine *re, const std::vector<DType>& prob): re(re) {
    reinit_state(prob);
  }

  /*!
   * \brief Draw one category id by walking from the root to a leaf.
   * \return the sampled category id.
   */
  Idx draw() override {
    int64_t cur = 1;
    const DType p = re->Uniform<DType>(0, weight[cur]);
    DType accum = 0.;  // total weight of everything to the left of the current subtree
    while (cur < num_leafs) {
      const DType w_l = weight[cur * 2], w_r = weight[cur * 2 + 1];
      const DType pivot = accum + w_l;
      // Never step into an empty subtree while its sibling still holds weight:
      // - w_r > 0 keeps rounding errors from pushing us into an empty right subtree;
      // - w_l <= 0 keeps p == pivot from pushing us into an empty left subtree.
      const bool go_right = w_r > 0 && (p > pivot || w_l <= 0);
      if (go_right) {
        cur = cur * 2 + 1;
        accum = pivot;
      } else {
        cur = cur * 2;
      }
    }
    const Idx rst = static_cast<Idx>(cur - num_leafs);
    if (!replace) {
      // Remove the drawn leaf and refresh the sums of all its ancestors.
      weight[cur] = 0.;
      for (cur /= 2; cur >= 1; cur /= 2)
        weight[cur] = weight[cur * 2] + weight[cur * 2 + 1];
    }
    return rst;
  }
};
}; // namespace dgl
#endif // DGL_SAMPLE_UTILS_H_
#include <gtest/gtest.h>
#include <dgl/sample_utils.h>
#include <vector>
#include <algorithm>
#include <iostream>
#include "./common.h"
using namespace dgl;
/*
 * Check that, when sampling with replacement, the empirical frequency of every
 * category is within 1e-2 of its normalized probability for all three samplers.
 */
template <typename Idx, typename DType>
void _TestWithReplacement(RandomEngine *re) {
  const Idx n_categories = 100;
  const Idx n_rolls = 1000000;

  // Random likelihoods, normalized afterwards so they can be compared to frequencies.
  std::vector<DType> prob;
  DType accum = 0.;
  for (Idx i = 0; i < n_categories; ++i) {
    prob.push_back(re->Uniform<DType>());
    accum += prob.back();
  }
  for (DType& p : prob)
    p /= accum;

  // Roll the sampler n_rolls times and compare hit rates with the target distribution.
  auto verify_frequencies = [&](BaseSampler<Idx, DType, true> *sampler) {
    std::vector<Idx> hits(n_categories, 0);
    for (Idx roll = 0; roll < n_rolls; ++roll) {
      const Idx picked = sampler->draw();
      hits[picked]++;
    }
    for (Idx c = 0; c < n_categories; ++c)
      ASSERT_NEAR(static_cast<DType>(hits[c]) / n_rolls, prob[c], 1e-2);
  };

  AliasSampler<Idx, DType, true> as(re, prob);
  CDFSampler<Idx, DType, true> cs(re, prob);
  TreeSampler<Idx, DType, true> ts(re, prob);
  BaseSampler<Idx, DType, true>* samplers[] = {&as, &cs, &ts};
  for (BaseSampler<Idx, DType, true>* sampler : samplers)
    verify_frequencies(sampler);
}
// Runs the with-replacement frequency check for every index/value type pair,
// reseeding the thread-local engine with the same seed before each run.
TEST(SampleUtilsTest, TestWithReplacement) {
  RandomEngine* const engine = RandomEngine::ThreadLocal();
  const auto reseed = [engine]() { engine->SetSeed(42); };

  reseed();
  _TestWithReplacement<int32_t, float>(engine);
  reseed();
  _TestWithReplacement<int32_t, double>(engine);
  reseed();
  _TestWithReplacement<int64_t, float>(engine);
  reseed();
  _TestWithReplacement<int64_t, double>(engine);
}
/*
 * Check the order of samples drawn without replacement from weights that are
 * orders of magnitude apart: every sampler is expected to return the
 * categories in descending order of weight.
 */
template <typename Idx, typename DType>
void _TestWithoutReplacementOrder(RandomEngine *re) {
  std::vector<DType> prob = {1e6, 1e-6, 1e-2, 1e2};
  std::vector<Idx> ground_truth = {0, 3, 2, 1};

  // Draw as many samples as there are categories and compare with the expected order.
  auto expect_order = [&ground_truth](BaseSampler<Idx, DType, false> *sampler) {
    for (const Idx& expected : ground_truth) {
      const Idx drawn = sampler->draw();
      ASSERT_EQ(drawn, expected);
    }
  };

  AliasSampler<Idx, DType, false> as(re, prob);
  CDFSampler<Idx, DType, false> cs(re, prob);
  TreeSampler<Idx, DType, false> ts(re, prob);
  BaseSampler<Idx, DType, false>* samplers[] = {&as, &cs, &ts};
  for (BaseSampler<Idx, DType, false>* sampler : samplers)
    expect_order(sampler);
}
// Runs the without-replacement order check for every index/value type pair,
// reseeding the thread-local engine with the same seed before each run.
TEST(SampleUtilsTest, TestWithoutReplacementOrder) {
  RandomEngine* const engine = RandomEngine::ThreadLocal();
  const auto reseed = [engine]() { engine->SetSeed(42); };

  reseed();
  _TestWithoutReplacementOrder<int32_t, float>(engine);
  reseed();
  _TestWithoutReplacementOrder<int32_t, double>(engine);
  reseed();
  _TestWithoutReplacementOrder<int64_t, float>(engine);
  reseed();
  _TestWithoutReplacementOrder<int64_t, double>(engine);
}
/*
 * Check that drawing N samples without replacement from N categories returns
 * every category exactly once, for all three samplers.
 */
template <typename Idx, typename DType>
void _TestWithoutReplacementUnique(RandomEngine *re) {
  const Idx N = 1000000;

  // Random, unnormalized likelihoods.
  std::vector<DType> likelihood;
  for (Idx i = 0; i < N; ++i)
    likelihood.push_back(re->Uniform<DType>());

  // Exhaust the sampler and verify that each category was seen exactly once.
  auto expect_each_once = [N](BaseSampler<Idx, DType, false> *sampler) {
    std::vector<int> cnt(N, 0);
    for (Idx roll = 0; roll < N; ++roll) {
      const Idx picked = sampler->draw();
      cnt[picked]++;
    }
    for (Idx c = 0; c < N; ++c)
      ASSERT_EQ(cnt[c], 1);
  };

  AliasSampler<Idx, DType, false> as(re, likelihood);
  CDFSampler<Idx, DType, false> cs(re, likelihood);
  TreeSampler<Idx, DType, false> ts(re, likelihood);
  BaseSampler<Idx, DType, false>* samplers[] = {&as, &cs, &ts};
  for (BaseSampler<Idx, DType, false>* sampler : samplers)
    expect_each_once(sampler);
}
// Runs the without-replacement uniqueness check for every index/value type pair,
// reseeding the thread-local engine with the same seed before each run.
TEST(SampleUtilsTest, TestWithoutReplacementUnique) {
  RandomEngine* const engine = RandomEngine::ThreadLocal();
  const auto reseed = [engine]() { engine->SetSeed(42); };

  reseed();
  _TestWithoutReplacementUnique<int32_t, float>(engine);
  reseed();
  _TestWithoutReplacementUnique<int32_t, double>(engine);
  reseed();
  _TestWithoutReplacementUnique<int64_t, float>(engine);
  reseed();
  _TestWithoutReplacementUnique<int64_t, double>(engine);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment