Unverified Commit f1420d19 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[Refactor] Renaming class methods of sampler utilities to improve readability (#1180)

* upd

* upd
parent 31911834
......@@ -17,6 +17,7 @@
#include "random.h"
namespace dgl {
namespace utils {
template <
typename Idx,
......@@ -24,7 +25,7 @@ template <
bool replace>
class BaseSampler {
public:
virtual Idx draw() {
virtual Idx Draw() {
LOG(INFO) << "Not implemented yet.";
return 0;
}
......@@ -52,14 +53,14 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
std::vector<bool> used; // indicate availability, activated when replace=false;
std::vector<Idx> id_mapping; // index mapping, activated when replace=false;
inline Idx map(Idx x) const {
inline Idx Map(Idx x) const { // Map consecutive indices to unused elements
if (replace)
return x;
else
return id_mapping[x];
}
void rebuild(const std::vector<DType>& prob) {
void Reconstruct(const std::vector<DType>& prob) { // Reconstruct alias table
N = 0;
accum = 0.;
taken = 0.;
......@@ -79,7 +80,7 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
std::fill(U.begin(), U.end(), avg); // initialize U
std::queue<std::pair<Idx, DType> > under, over;
for (Idx i = 0; i < N; ++i) {
DType p = prob[map(i)];
DType p = prob[Map(i)];
if (p > avg)
over.push(std::make_pair(i, p));
else
......@@ -102,33 +103,33 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
}
public:
void reinit_state(const std::vector<DType>& prob) {
void ResetState(const std::vector<DType>& prob) {
used.resize(prob.size());
if (!replace)
_prob = prob;
std::fill(used.begin(), used.end(), false);
rebuild(prob);
Reconstruct(prob);
}
explicit AliasSampler(RandomEngine* re, const std::vector<DType>& prob): re(re) {
reinit_state(prob);
ResetState(prob);
}
~AliasSampler() {}
Idx draw() {
Idx Draw() {
DType avg = accum / N;
if (!replace) {
if (2 * taken >= accum)
rebuild(_prob);
Reconstruct(_prob);
while (true) {
DType dice = re->Uniform<DType>(0, N);
Idx i = static_cast<Idx>(dice), rst;
DType p = (dice - i) * avg;
if (p <= U[map(i)]) {
rst = map(i);
if (p <= U[Map(i)]) {
rst = Map(i);
} else {
rst = map(K[i]);
rst = Map(K[i]);
}
DType cap = _prob[rst];
if (!used[rst]) {
......@@ -141,10 +142,10 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
DType dice = re->Uniform<DType>(0, N);
Idx i = static_cast<Idx>(dice);
DType p = (dice - i) * avg;
if (p <= U[map(i)])
return map(i);
if (p <= U[Map(i)])
return Map(i);
else
return map(K[i]);
return Map(K[i]);
}
};
......@@ -170,14 +171,14 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
std::vector<bool> used; // indicate availability, activated when replace=false;
std::vector<Idx> id_mapping; // indicate index mapping, activated when replace=false;
inline Idx map(Idx x) const {
inline Idx Map(Idx x) const { // Map consecutive indices to unused elements
if (replace)
return x;
else
return id_mapping[x];
}
void rebuild(const std::vector<DType>& prob) {
void Reconstruct(const std::vector<DType>& prob) { // Reconstruct CDF
N = 0;
accum = 0.;
taken = 0.;
......@@ -197,28 +198,28 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
}
public:
void reinit_state(const std::vector<DType>& prob) {
void ResetState(const std::vector<DType>& prob) {
used.resize(prob.size());
if (!replace)
_prob = prob;
std::fill(used.begin(), used.end(), false);
rebuild(prob);
Reconstruct(prob);
}
explicit CDFSampler(RandomEngine *re, const std::vector<DType>& prob): re(re) {
reinit_state(prob);
ResetState(prob);
}
~CDFSampler() {}
Idx draw() {
Idx Draw() {
DType eps = std::numeric_limits<DType>::min();
if (!replace) {
if (2 * taken >= accum)
rebuild(_prob);
Reconstruct(_prob);
while (true) {
DType p = std::max(re->Uniform<DType>(0., accum), eps);
Idx rst = map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
Idx rst = Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
DType cap = _prob[rst];
if (!used[rst]) {
used[rst] = true;
......@@ -228,7 +229,7 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
}
}
DType p = std::max(re->Uniform<DType>(0., accum), eps);
return map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
return Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
}
};
......@@ -251,7 +252,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
int64_t N, num_leafs;
public:
void reinit_state(const std::vector<DType>& prob) {
void ResetState(const std::vector<DType>& prob) {
std::fill(weight.begin(), weight.end(), 0);
for (int i = 0; i < prob.size(); ++i)
weight[num_leafs + i] = prob[i];
......@@ -265,10 +266,10 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
num_leafs *= 2;
N = num_leafs * 2;
weight.resize(N);
reinit_state(prob);
ResetState(prob);
}
Idx draw() {
Idx Draw() {
int64_t cur = 1;
DType p = re->Uniform<DType>(0, weight[cur]);
DType accum = 0.;
......@@ -295,6 +296,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
}
};
}; // namespace utils
}; // namespace dgl
#endif // DGL_SAMPLE_UTILS_H_
......@@ -21,19 +21,19 @@ void _TestWithReplacement(RandomEngine *re) {
prob[i] /= accum;
auto _check_given_sampler = [n_categories, n_rolls, &prob](
BaseSampler<Idx, DType, true> *s) {
utils::BaseSampler<Idx, DType, true> *s) {
std::vector<Idx> counter(n_categories, 0);
for (Idx i = 0; i < n_rolls; ++i) {
Idx dice = s->draw();
Idx dice = s->Draw();
counter[dice]++;
}
for (Idx i = 0; i < n_categories; ++i)
ASSERT_NEAR(static_cast<DType>(counter[i]) / n_rolls, prob[i], 1e-2);
};
AliasSampler<Idx, DType, true> as(re, prob);
CDFSampler<Idx, DType, true> cs(re, prob);
TreeSampler<Idx, DType, true> ts(re, prob);
utils::AliasSampler<Idx, DType, true> as(re, prob);
utils::CDFSampler<Idx, DType, true> cs(re, prob);
utils::TreeSampler<Idx, DType, true> ts(re, prob);
_check_given_sampler(&as);
_check_given_sampler(&cs);
_check_given_sampler(&ts);
......@@ -57,16 +57,16 @@ void _TestWithoutReplacementOrder(RandomEngine *re) {
std::vector<Idx> ground_truth = {0, 3, 2, 1};
auto _check_given_sampler = [&ground_truth](
BaseSampler<Idx, DType, false> *s) {
utils::BaseSampler<Idx, DType, false> *s) {
for (size_t i = 0; i < ground_truth.size(); ++i) {
Idx dice = s->draw();
Idx dice = s->Draw();
ASSERT_EQ(dice, ground_truth[i]);
}
};
AliasSampler<Idx, DType, false> as(re, prob);
CDFSampler<Idx, DType, false> cs(re, prob);
TreeSampler<Idx, DType, false> ts(re, prob);
utils::AliasSampler<Idx, DType, false> as(re, prob);
utils::CDFSampler<Idx, DType, false> cs(re, prob);
utils::TreeSampler<Idx, DType, false> ts(re, prob);
_check_given_sampler(&as);
_check_given_sampler(&cs);
_check_given_sampler(&ts);
......@@ -92,19 +92,19 @@ void _TestWithoutReplacementUnique(RandomEngine *re) {
likelihood.push_back(re->Uniform<DType>());
auto _check_given_sampler = [N](
BaseSampler<Idx, DType, false> *s) {
utils::BaseSampler<Idx, DType, false> *s) {
std::vector<int> cnt(N, 0);
for (Idx i = 0; i < N; ++i) {
Idx dice = s->draw();
Idx dice = s->Draw();
cnt[dice]++;
}
for (Idx i = 0; i < N; ++i)
ASSERT_EQ(cnt[i], 1);
};
AliasSampler<Idx, DType, false> as(re, likelihood);
CDFSampler<Idx, DType, false> cs(re, likelihood);
TreeSampler<Idx, DType, false> ts(re, likelihood);
utils::AliasSampler<Idx, DType, false> as(re, likelihood);
utils::CDFSampler<Idx, DType, false> cs(re, likelihood);
utils::TreeSampler<Idx, DType, false> ts(re, likelihood);
_check_given_sampler(&as);
_check_given_sampler(&cs);
_check_given_sampler(&ts);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment