"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a5b75c8a5e9f31e3a199a17eb7e1dcbe64baad99"
Unverified commit 6b022d2f authored by Qidong Su, committed by GitHub

[Sampler] BiasedChoice sampler (#1665)



* update

* update

* update

* update

* update

* update

* update

* fix

* fix

* update

* doc

* doc

* fix

* fix

Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 4579bbf7
......@@ -178,6 +178,57 @@ class RandomEngine {
return ret;
}
/*!
* \brief Pick random integers, with a different probability for each segment.
*
* For example, split=[0, 4, 10] and bias=[1.5, 1] means picking integers from 0 to 9,
* which are divided into two segments: 0-3 form the first segment and the rest belong
* to the second. Each candidate in the first segment carries a weight (bias) of 1.5,
* and each candidate in the second segment a weight of 1.
*
*   candidate | 0 1 2 3 | 4 5 6 7 8 9 |
*   split     ^         ^             ^
*   bias      |   1.5   |      1      |
*
* The complexity of this operator is O(k * log(T)), where k is the number of integers
* to pick and T is the number of segments. This is much faster than assigning a
* probability to every candidate, which costs O(k * log(N)) with N the total number
* of candidates.
*
* If replace is false, num must not be larger than the population.
*
* \tparam IdxType Return integer type
* \tparam FloatType Bias value type
* \param num Number of integers to choose
* \param split Array of T+1 split positions delimiting the segments (including start and end)
* \param bias Array of T weights, one per segment
* \param out The output buffer to write the selected indices to.
* \param replace If true, choose with replacement.
*/
template <typename IdxType, typename FloatType>
void BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace = true);
/*!
* \brief Pick random integers, with a different probability for each segment.
*
* If replace is false, num must not be larger than the population.
*
* \tparam IdxType Return integer type
* \param num Number of integers to choose
* \param split Split positions of different segments
* \param bias Weights of different segments
* \param replace If true, choose with replacement.
*/
template <typename IdxType, typename FloatType>
IdArray BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, bool replace = true) {
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0});
BiasedChoice<IdxType, FloatType>(num, split, bias, static_cast<IdxType*>(ret->data), replace);
return ret;
}
private:
std::default_random_engine rng_;
};
......
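For orientation, a minimal usage sketch of the new overload (illustrative only, not part of this diff), mirroring the split=[0, 4, 10], bias=[1.5, 1] example from the doc comment; it uses NDArray::FromVector and RandomEngine::ThreadLocal() exactly as the unit tests further below do, and the function name and header paths are assumptions:

#include <cstdint>
#include <vector>
#include <dgl/array.h>    // IdArray, FloatArray, NDArray (header path assumed)
#include <dgl/random.h>   // RandomEngine (header path assumed)

void BiasedChoiceExample() {
  using namespace dgl;
  RandomEngine* re = RandomEngine::ThreadLocal();
  // Two segments over candidates 0..9: [0, 4) with bias 1.5, [4, 10) with bias 1.
  int64_t split[] = {0, 4, 10};
  FloatArray bias = NDArray::FromVector(std::vector<float>({1.5f, 1.0f}));
  // Draw 5 candidates with replacement; every entry of `picked` lies in [0, 10).
  IdArray picked = re->BiasedChoice<int64_t, float>(5, split, bias, /*replace=*/true);
}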
......@@ -118,4 +118,56 @@ template void RandomEngine::UniformChoice<int64_t>(int64_t num,
int64_t population,
int64_t* out, bool replace);
template <typename IdxType, typename FloatType>
void RandomEngine::BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace) {
const int64_t num_tags = bias->shape[0];
const FloatType *bias_data = static_cast<FloatType *>(bias->data);
IdxType total_node_num = 0;
FloatArray prob = NDArray::Empty({num_tags}, bias->dtype, bias->ctx);
FloatType *prob_data = static_cast<FloatType *>(prob->data);
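// Each segment's weight is (#candidates in the segment) * its bias, so drawing a segment
// by this weight and then a uniform member inside it gives every candidate an effective
// weight of bias[tag].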
for (int64_t tag = 0 ; tag < num_tags; ++tag) {
int64_t tag_num_nodes = split[tag+1] - split[tag];
total_node_num += tag_num_nodes;
FloatType tag_bias = bias_data[tag];
prob_data[tag] = tag_num_nodes * tag_bias;
}
if (replace) {
auto sampler = utils::TreeSampler<IdxType, FloatType, true>(this, prob);
for (IdxType i = 0; i < num; ++i) {
const int64_t tag = sampler.Draw();
const IdxType tag_num_nodes = split[tag+1] - split[tag];
out[i] = RandInt(tag_num_nodes) + split[tag];
}
} else {
utils::TreeSampler<int64_t, FloatType, false> sampler(this, prob, bias_data);
CHECK_GE(total_node_num, num)
<< "Cannot take more sample than population when 'replace=false'";
// we use hash set here. Maybe in the future we should support reservoir algorithm
std::vector<std::unordered_set<IdxType>> selected(num_tags);
for (IdxType i = 0 ; i < num ; ++i) {
const int64_t tag = sampler.Draw();
bool inserted = false;
const IdxType tag_num_nodes = split[tag+1] - split[tag];
IdxType selected_node;
while (!inserted) {
CHECK_LT(selected[tag].size(), tag_num_nodes)
<< "Cannot take more sample than population when 'replace=false'";
selected_node = RandInt(tag_num_nodes);
inserted = selected[tag].insert(selected_node).second;
}
out[i] = selected_node + split[tag];
}
}
}
template void RandomEngine::BiasedChoice<int32_t, float>(
int32_t, const int32_t*, FloatArray, int32_t*, bool);
template void RandomEngine::BiasedChoice<int32_t, double>(
int32_t, const int32_t*, FloatArray, int32_t*, bool);
template void RandomEngine::BiasedChoice<int64_t, float>(
int64_t, const int64_t*, FloatArray, int64_t*, bool);
template void RandomEngine::BiasedChoice<int64_t, double>(
int64_t, const int64_t*, FloatArray, int64_t*, bool);
}; // namespace dgl
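As a worked illustration of the weights built above (using the doc-comment numbers, nothing beyond the code itself): with split = [0, 4, 10] and bias = [1.5, 1], the loop yields prob = {4 * 1.5, 6 * 1} = {6, 6}, so each segment is drawn with probability 6 / 12 = 0.5; the uniform pick inside the segment then gives each of the 4 candidates of the first segment probability 0.5 / 4 = 0.125 and each of the 6 candidates of the second segment 0.5 / 6 ≈ 0.083, i.e. per-candidate probabilities in the ratio 1.5 : 1, as intended.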
......@@ -258,6 +258,7 @@ class TreeSampler: public BaseSampler<Idx> {
std::vector<DType> weight; // accumulated likelihood of subtrees.
int64_t N;
int64_t num_leafs;
const DType *decrease;
public:
void ResetState(FloatArray prob) {
......@@ -270,7 +271,8 @@ class TreeSampler: public BaseSampler<Idx> {
weight[i] = weight[i * 2] + weight[i * 2 + 1];
}
explicit TreeSampler(RandomEngine *re, FloatArray prob): re(re) {
explicit TreeSampler(RandomEngine *re, FloatArray prob, const DType* decrease = nullptr)
: re(re), decrease(decrease) {
num_leafs = 1;
while (num_leafs < prob->shape[0])
num_leafs *= 2;
......@@ -279,6 +281,17 @@ class TreeSampler: public BaseSampler<Idx> {
ResetState(prob);
}
/* Pick an element from the given distribution and update the tree.
*
* The parameter decrease is an array whose length equals the number of categories.
* Every time an element of category x is picked, the weight of that category is reduced
* by decrease[x]. This supports the case where a category contains multiple candidates,
* with decrease[x] being the weight of a single candidate of category x.
*
* When decrease == nullptr, each category is assumed to contain exactly one candidate,
* and the weight of the chosen category is set directly to 0.
*
*/
Idx Draw() {
int64_t cur = 1;
DType p = re->Uniform<DType>(0, weight[cur]);
......@@ -296,7 +309,7 @@ class TreeSampler: public BaseSampler<Idx> {
if (!replace) {
while (cur >= 1) {
if (cur >= num_leafs)
weight[cur] = 0.;
weight[cur] = this->decrease ? weight[cur] - this->decrease[rst] : 0.;
else
weight[cur] = weight[cur * 2] + weight[cur * 2 + 1];
cur /= 2;
......
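To make the role of decrease concrete, here is a small self-contained sketch (illustrative only; ToyTreeSampler and its members are hypothetical names, not DGL code) of the complete-binary-tree sampler this change extends: leaves hold per-category weights, internal nodes hold the sum of their children, Draw() descends by comparing a uniform draw against the left subtree's weight, and, when sampling without replacement, walks back up the leaf-to-root path subtracting decrease[leaf] (or zeroing the leaf when no decrease array is given) and re-summing the ancestors.

#include <cstdint>
#include <random>
#include <vector>

class ToyTreeSampler {
 public:
  explicit ToyTreeSampler(const std::vector<double>& prob, const double* decrease = nullptr)
      : decrease_(decrease) {
    num_leafs_ = 1;
    while (num_leafs_ < static_cast<int64_t>(prob.size())) num_leafs_ *= 2;
    weight_.assign(num_leafs_ * 2, 0.0);
    for (size_t i = 0; i < prob.size(); ++i) weight_[num_leafs_ + i] = prob[i];
    for (int64_t i = num_leafs_ - 1; i >= 1; --i)
      weight_[i] = weight_[i * 2] + weight_[i * 2 + 1];
  }

  int64_t Draw(std::mt19937* rng, bool replace) {
    int64_t cur = 1;
    std::uniform_real_distribution<double> dist(0.0, weight_[cur]);
    double p = dist(*rng);
    while (cur < num_leafs_) {          // descend to a leaf
      if (p < weight_[cur * 2]) {
        cur = cur * 2;                  // go left
      } else {
        p -= weight_[cur * 2];
        cur = cur * 2 + 1;              // go right
      }
    }
    int64_t rst = cur - num_leafs_;     // category index of the chosen leaf
    if (!replace) {
      while (cur >= 1) {                // update weights on the leaf-to-root path
        if (cur >= num_leafs_)
          weight_[cur] = decrease_ ? weight_[cur] - decrease_[rst] : 0.0;
        else
          weight_[cur] = weight_[cur * 2] + weight_[cur * 2 + 1];
        cur /= 2;
      }
    }
    return rst;
  }

 private:
  std::vector<double> weight_;
  int64_t num_leafs_;
  const double* decrease_;
};

With decrease set to the per-candidate weights, a category that holds several candidates stays drawable until all of its candidates have been taken, which is exactly what the replace=false branch of BiasedChoice needs.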
......@@ -227,3 +227,56 @@ TEST(RandomTest, TestUniformChoice) {
_TestUniformChoice<int32_t>(re);
_TestUniformChoice<int64_t>(re);
}
template <typename Idx, typename FloatType>
void _TestBiasedChoice(RandomEngine* re) {
re->SetSeed(42);
// num == 0
{
Idx split[] = {0, 1, 2};
FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 3}));
IdArray rst = re->BiasedChoice<Idx, FloatType>(0, split, bias, true);
ASSERT_EQ(rst->shape[0], 0);
}
// basic test
{
Idx sample_num = 100000;
Idx population = 1000000;
Idx split[] = {0, population/2, population};
FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 3}));
IdArray rst = re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, true);
auto rst_data = static_cast<Idx *>(rst->data);
Idx larger = 0;
for (Idx i = 0 ; i < sample_num ; ++i)
if (rst_data[i] >= population / 2)
larger++;
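// With bias {1, 3} over two equally sized halves, the expected fraction of draws from
// the upper half is 3 / (1 + 3) = 0.75.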
ASSERT_LE(fabs((double)larger / sample_num - 0.75), 1e-2);
}
// without replacement
{
Idx sample_num = 500;
Idx population = 1000;
Idx split[] = {0, sample_num, population};
FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 0}));
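// The second segment has bias 0, so every draw must come from the first segment
// [0, sample_num); without replacement the sample must cover it exactly.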
IdArray rst = re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, false);
auto rst_data = static_cast<Idx *>(rst->data);
std::set<Idx> idxset;
for (int64_t i = 0; i < sample_num; ++i) {
Idx x = rst_data[i];
ASSERT_LT(x, sample_num);
idxset.insert(x);
}
ASSERT_EQ(idxset.size(), sample_num);
}
}
TEST(RandomTest, TestBiasedChoice) {
RandomEngine* re = RandomEngine::ThreadLocal();
_TestBiasedChoice<int32_t, float>(re);
_TestBiasedChoice<int64_t, float>(re);
_TestBiasedChoice<int32_t, double>(re);
_TestBiasedChoice<int64_t, double>(re);
}
\ No newline at end of file