Unverified commit f1420d19, authored by Zihao Ye, committed by GitHub
Browse files

[Refactor] Renaming class methods of sampler utilities to improve readability (#1180)

* upd

* upd
parent 31911834
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "random.h" #include "random.h"
namespace dgl { namespace dgl {
namespace utils {
template < template <
typename Idx, typename Idx,
...@@ -24,7 +25,7 @@ template < ...@@ -24,7 +25,7 @@ template <
bool replace> bool replace>
class BaseSampler { class BaseSampler {
public: public:
virtual Idx draw() { virtual Idx Draw() {
LOG(INFO) << "Not implemented yet."; LOG(INFO) << "Not implemented yet.";
return 0; return 0;
} }
...@@ -52,14 +53,14 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> { ...@@ -52,14 +53,14 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
std::vector<bool> used; // indicate availability, activated when replace=false; std::vector<bool> used; // indicate availability, activated when replace=false;
std::vector<Idx> id_mapping; // index mapping, activated when replace=false; std::vector<Idx> id_mapping; // index mapping, activated when replace=false;
inline Idx map(Idx x) const { inline Idx Map(Idx x) const { // Map consecutive indices to unused elements
if (replace) if (replace)
return x; return x;
else else
return id_mapping[x]; return id_mapping[x];
} }
void rebuild(const std::vector<DType>& prob) { void Reconstruct(const std::vector<DType>& prob) { // Reconstruct alias table
N = 0; N = 0;
accum = 0.; accum = 0.;
taken = 0.; taken = 0.;
...@@ -79,7 +80,7 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> { ...@@ -79,7 +80,7 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
std::fill(U.begin(), U.end(), avg); // initialize U std::fill(U.begin(), U.end(), avg); // initialize U
std::queue<std::pair<Idx, DType> > under, over; std::queue<std::pair<Idx, DType> > under, over;
for (Idx i = 0; i < N; ++i) { for (Idx i = 0; i < N; ++i) {
DType p = prob[map(i)]; DType p = prob[Map(i)];
if (p > avg) if (p > avg)
over.push(std::make_pair(i, p)); over.push(std::make_pair(i, p));
else else
...@@ -102,33 +103,33 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> { ...@@ -102,33 +103,33 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
} }
public: public:
void reinit_state(const std::vector<DType>& prob) { void ResetState(const std::vector<DType>& prob) {
used.resize(prob.size()); used.resize(prob.size());
if (!replace) if (!replace)
_prob = prob; _prob = prob;
std::fill(used.begin(), used.end(), false); std::fill(used.begin(), used.end(), false);
rebuild(prob); Reconstruct(prob);
} }
explicit AliasSampler(RandomEngine* re, const std::vector<DType>& prob): re(re) { explicit AliasSampler(RandomEngine* re, const std::vector<DType>& prob): re(re) {
reinit_state(prob); ResetState(prob);
} }
~AliasSampler() {} ~AliasSampler() {}
Idx draw() { Idx Draw() {
DType avg = accum / N; DType avg = accum / N;
if (!replace) { if (!replace) {
if (2 * taken >= accum) if (2 * taken >= accum)
rebuild(_prob); Reconstruct(_prob);
while (true) { while (true) {
DType dice = re->Uniform<DType>(0, N); DType dice = re->Uniform<DType>(0, N);
Idx i = static_cast<Idx>(dice), rst; Idx i = static_cast<Idx>(dice), rst;
DType p = (dice - i) * avg; DType p = (dice - i) * avg;
if (p <= U[map(i)]) { if (p <= U[Map(i)]) {
rst = map(i); rst = Map(i);
} else { } else {
rst = map(K[i]); rst = Map(K[i]);
} }
DType cap = _prob[rst]; DType cap = _prob[rst];
if (!used[rst]) { if (!used[rst]) {
...@@ -141,10 +142,10 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> { ...@@ -141,10 +142,10 @@ class AliasSampler: public BaseSampler<Idx, DType, replace> {
DType dice = re->Uniform<DType>(0, N); DType dice = re->Uniform<DType>(0, N);
Idx i = static_cast<Idx>(dice); Idx i = static_cast<Idx>(dice);
DType p = (dice - i) * avg; DType p = (dice - i) * avg;
if (p <= U[map(i)]) if (p <= U[Map(i)])
return map(i); return Map(i);
else else
return map(K[i]); return Map(K[i]);
} }
}; };
...@@ -170,14 +171,14 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> { ...@@ -170,14 +171,14 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
std::vector<bool> used; // indicate availability, activated when replace=false; std::vector<bool> used; // indicate availability, activated when replace=false;
std::vector<Idx> id_mapping; // indicate index mapping, activated when replace=false; std::vector<Idx> id_mapping; // indicate index mapping, activated when replace=false;
inline Idx map(Idx x) const { inline Idx Map(Idx x) const { // Map consecutive indices to unused elements
if (replace) if (replace)
return x; return x;
else else
return id_mapping[x]; return id_mapping[x];
} }
void rebuild(const std::vector<DType>& prob) { void Reconstruct(const std::vector<DType>& prob) { // Reconstruct CDF
N = 0; N = 0;
accum = 0.; accum = 0.;
taken = 0.; taken = 0.;
...@@ -197,28 +198,28 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> { ...@@ -197,28 +198,28 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
} }
public: public:
void reinit_state(const std::vector<DType>& prob) { void ResetState(const std::vector<DType>& prob) {
used.resize(prob.size()); used.resize(prob.size());
if (!replace) if (!replace)
_prob = prob; _prob = prob;
std::fill(used.begin(), used.end(), false); std::fill(used.begin(), used.end(), false);
rebuild(prob); Reconstruct(prob);
} }
explicit CDFSampler(RandomEngine *re, const std::vector<DType>& prob): re(re) { explicit CDFSampler(RandomEngine *re, const std::vector<DType>& prob): re(re) {
reinit_state(prob); ResetState(prob);
} }
~CDFSampler() {} ~CDFSampler() {}
Idx draw() { Idx Draw() {
DType eps = std::numeric_limits<DType>::min(); DType eps = std::numeric_limits<DType>::min();
if (!replace) { if (!replace) {
if (2 * taken >= accum) if (2 * taken >= accum)
rebuild(_prob); Reconstruct(_prob);
while (true) { while (true) {
DType p = std::max(re->Uniform<DType>(0., accum), eps); DType p = std::max(re->Uniform<DType>(0., accum), eps);
Idx rst = map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1); Idx rst = Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
DType cap = _prob[rst]; DType cap = _prob[rst];
if (!used[rst]) { if (!used[rst]) {
used[rst] = true; used[rst] = true;
...@@ -228,7 +229,7 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> { ...@@ -228,7 +229,7 @@ class CDFSampler: public BaseSampler<Idx, DType, replace> {
} }
} }
DType p = std::max(re->Uniform<DType>(0., accum), eps); DType p = std::max(re->Uniform<DType>(0., accum), eps);
return map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1); return Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
} }
}; };
...@@ -251,7 +252,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> { ...@@ -251,7 +252,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
int64_t N, num_leafs; int64_t N, num_leafs;
public: public:
void reinit_state(const std::vector<DType>& prob) { void ResetState(const std::vector<DType>& prob) {
std::fill(weight.begin(), weight.end(), 0); std::fill(weight.begin(), weight.end(), 0);
for (int i = 0; i < prob.size(); ++i) for (int i = 0; i < prob.size(); ++i)
weight[num_leafs + i] = prob[i]; weight[num_leafs + i] = prob[i];
...@@ -265,10 +266,10 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> { ...@@ -265,10 +266,10 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
num_leafs *= 2; num_leafs *= 2;
N = num_leafs * 2; N = num_leafs * 2;
weight.resize(N); weight.resize(N);
reinit_state(prob); ResetState(prob);
} }
Idx draw() { Idx Draw() {
int64_t cur = 1; int64_t cur = 1;
DType p = re->Uniform<DType>(0, weight[cur]); DType p = re->Uniform<DType>(0, weight[cur]);
DType accum = 0.; DType accum = 0.;
...@@ -295,6 +296,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> { ...@@ -295,6 +296,7 @@ class TreeSampler: public BaseSampler<Idx, DType, replace> {
} }
}; };
}; // namespace utils
}; // namespace dgl }; // namespace dgl
#endif // DGL_SAMPLE_UTILS_H_ #endif // DGL_SAMPLE_UTILS_H_
...@@ -21,19 +21,19 @@ void _TestWithReplacement(RandomEngine *re) { ...@@ -21,19 +21,19 @@ void _TestWithReplacement(RandomEngine *re) {
prob[i] /= accum; prob[i] /= accum;
auto _check_given_sampler = [n_categories, n_rolls, &prob]( auto _check_given_sampler = [n_categories, n_rolls, &prob](
BaseSampler<Idx, DType, true> *s) { utils::BaseSampler<Idx, DType, true> *s) {
std::vector<Idx> counter(n_categories, 0); std::vector<Idx> counter(n_categories, 0);
for (Idx i = 0; i < n_rolls; ++i) { for (Idx i = 0; i < n_rolls; ++i) {
Idx dice = s->draw(); Idx dice = s->Draw();
counter[dice]++; counter[dice]++;
} }
for (Idx i = 0; i < n_categories; ++i) for (Idx i = 0; i < n_categories; ++i)
ASSERT_NEAR(static_cast<DType>(counter[i]) / n_rolls, prob[i], 1e-2); ASSERT_NEAR(static_cast<DType>(counter[i]) / n_rolls, prob[i], 1e-2);
}; };
AliasSampler<Idx, DType, true> as(re, prob); utils::AliasSampler<Idx, DType, true> as(re, prob);
CDFSampler<Idx, DType, true> cs(re, prob); utils::CDFSampler<Idx, DType, true> cs(re, prob);
TreeSampler<Idx, DType, true> ts(re, prob); utils::TreeSampler<Idx, DType, true> ts(re, prob);
_check_given_sampler(&as); _check_given_sampler(&as);
_check_given_sampler(&cs); _check_given_sampler(&cs);
_check_given_sampler(&ts); _check_given_sampler(&ts);
...@@ -57,16 +57,16 @@ void _TestWithoutReplacementOrder(RandomEngine *re) { ...@@ -57,16 +57,16 @@ void _TestWithoutReplacementOrder(RandomEngine *re) {
std::vector<Idx> ground_truth = {0, 3, 2, 1}; std::vector<Idx> ground_truth = {0, 3, 2, 1};
auto _check_given_sampler = [&ground_truth]( auto _check_given_sampler = [&ground_truth](
BaseSampler<Idx, DType, false> *s) { utils::BaseSampler<Idx, DType, false> *s) {
for (size_t i = 0; i < ground_truth.size(); ++i) { for (size_t i = 0; i < ground_truth.size(); ++i) {
Idx dice = s->draw(); Idx dice = s->Draw();
ASSERT_EQ(dice, ground_truth[i]); ASSERT_EQ(dice, ground_truth[i]);
} }
}; };
AliasSampler<Idx, DType, false> as(re, prob); utils::AliasSampler<Idx, DType, false> as(re, prob);
CDFSampler<Idx, DType, false> cs(re, prob); utils::CDFSampler<Idx, DType, false> cs(re, prob);
TreeSampler<Idx, DType, false> ts(re, prob); utils::TreeSampler<Idx, DType, false> ts(re, prob);
_check_given_sampler(&as); _check_given_sampler(&as);
_check_given_sampler(&cs); _check_given_sampler(&cs);
_check_given_sampler(&ts); _check_given_sampler(&ts);
...@@ -92,19 +92,19 @@ void _TestWithoutReplacementUnique(RandomEngine *re) { ...@@ -92,19 +92,19 @@ void _TestWithoutReplacementUnique(RandomEngine *re) {
likelihood.push_back(re->Uniform<DType>()); likelihood.push_back(re->Uniform<DType>());
auto _check_given_sampler = [N]( auto _check_given_sampler = [N](
BaseSampler<Idx, DType, false> *s) { utils::BaseSampler<Idx, DType, false> *s) {
std::vector<int> cnt(N, 0); std::vector<int> cnt(N, 0);
for (Idx i = 0; i < N; ++i) { for (Idx i = 0; i < N; ++i) {
Idx dice = s->draw(); Idx dice = s->Draw();
cnt[dice]++; cnt[dice]++;
} }
for (Idx i = 0; i < N; ++i) for (Idx i = 0; i < N; ++i)
ASSERT_EQ(cnt[i], 1); ASSERT_EQ(cnt[i], 1);
}; };
AliasSampler<Idx, DType, false> as(re, likelihood); utils::AliasSampler<Idx, DType, false> as(re, likelihood);
CDFSampler<Idx, DType, false> cs(re, likelihood); utils::CDFSampler<Idx, DType, false> cs(re, likelihood);
TreeSampler<Idx, DType, false> ts(re, likelihood); utils::TreeSampler<Idx, DType, false> ts(re, likelihood);
_check_given_sampler(&as); _check_given_sampler(&as);
_check_given_sampler(&cs); _check_given_sampler(&cs);
_check_given_sampler(&ts); _check_given_sampler(&ts);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment