"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a80f6892003e102f56bc956e9f8707b52c5d4487"
Unverified Commit b2e35e6a authored by pawelpiotrowicz's avatar pawelpiotrowicz Committed by GitHub
Browse files

[Performance] Linear UniformChoice optimization (#2710)


Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: default avatarZihao Ye <expye@outlook.com>
parent 54c74803
...@@ -4,15 +4,15 @@ ...@@ -4,15 +4,15 @@
* \brief Non-uniform discrete sampling implementation * \brief Non-uniform discrete sampling implementation
*/ */
#include <dgl/random.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <vector> #include <dgl/random.h>
#include <numeric> #include <numeric>
#include <vector>
#include "sample_utils.h" #include "sample_utils.h"
namespace dgl { namespace dgl {
template<typename IdxType> template <typename IdxType>
IdxType RandomEngine::Choice(FloatArray prob) { IdxType RandomEngine::Choice(FloatArray prob) {
IdxType ret = 0; IdxType ret = 0;
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, "probability", { ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, "probability", {
...@@ -26,14 +26,14 @@ IdxType RandomEngine::Choice(FloatArray prob) { ...@@ -26,14 +26,14 @@ IdxType RandomEngine::Choice(FloatArray prob) {
template int32_t RandomEngine::Choice<int32_t>(FloatArray); template int32_t RandomEngine::Choice<int32_t>(FloatArray);
template int64_t RandomEngine::Choice<int64_t>(FloatArray); template int64_t RandomEngine::Choice<int64_t>(FloatArray);
template <typename IdxType, typename FloatType>
template<typename IdxType, typename FloatType> void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out,
void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out, bool replace) { bool replace) {
const IdxType N = prob->shape[0]; const IdxType N = prob->shape[0];
if (!replace) if (!replace)
CHECK_LE(num, N) << "Cannot take more sample than population when 'replace=false'"; CHECK_LE(num, N)
if (num == N && !replace) << "Cannot take more sample than population when 'replace=false'";
std::iota(out, out + num, 0); if (num == N && !replace) std::iota(out, out + num, 0);
utils::BaseSampler<IdxType>* sampler = nullptr; utils::BaseSampler<IdxType>* sampler = nullptr;
if (replace) { if (replace) {
...@@ -41,58 +41,81 @@ void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out, bool repla ...@@ -41,58 +41,81 @@ void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out, bool repla
} else { } else {
sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob); sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob);
} }
for (IdxType i = 0; i < num; ++i) for (IdxType i = 0; i < num; ++i) out[i] = sampler->Draw();
out[i] = sampler->Draw();
delete sampler; delete sampler;
} }
template void RandomEngine::Choice<int32_t, float>( template void RandomEngine::Choice<int32_t, float>(int32_t num, FloatArray prob,
int32_t num, FloatArray prob, int32_t* out, bool replace); int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, float>( template void RandomEngine::Choice<int64_t, float>(int64_t num, FloatArray prob,
int64_t num, FloatArray prob, int64_t* out, bool replace); int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, double>( template void RandomEngine::Choice<int32_t, double>(int32_t num,
int32_t num, FloatArray prob, int32_t* out, bool replace); FloatArray prob,
template void RandomEngine::Choice<int64_t, double>( int32_t* out, bool replace);
int64_t num, FloatArray prob, int64_t* out, bool replace); template void RandomEngine::Choice<int64_t, double>(int64_t num,
FloatArray prob,
int64_t* out, bool replace);
template <typename IdxType> template <typename IdxType>
void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out, bool replace) { void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
bool replace) {
if (!replace) if (!replace)
CHECK_LE(num, population) << "Cannot take more sample than population when 'replace=false'"; CHECK_LE(num, population)
<< "Cannot take more sample than population when 'replace=false'";
if (replace) { if (replace) {
for (IdxType i = 0; i < num; ++i) for (IdxType i = 0; i < num; ++i) out[i] = RandInt(population);
out[i] = RandInt(population);
} else { } else {
if (num < population / 10) { // TODO(minjie): may need a better threshold here if (num <
// use hash set population / 10) { // TODO(minjie): may need a better threshold here
// In the best scenario, time complexity is O(num), i.e., no conflict. // if set of numbers is small (up to 128) use linear search to verify
// // uniqueness this operation is cheaper for CPU.
// Let k be num / population, the expected number of extra sampling steps is roughly if (num && num < 64) {
// k^2 / (1-k) * population, which means in the worst case scenario, *out = RandInt(population);
// the time complexity is O(population^2). In practice, we use 1/10 since auto b = out + 1;
// std::unordered_set is pretty slow. auto e = b + num - 1;
std::unordered_set<IdxType> selected; while (b != e) {
while (selected.size() < num) { // put the new value at the end
selected.insert(RandInt(population)); *b = RandInt(population);
// Check if a new value doesn't exist in current range(out,b)
// otherwise get a new value until we haven't unique range of
// elements.
auto it = std::find(out, b, *b);
if (it != b) continue;
++b;
}
} else {
// use hash set
// In the best scenario, time complexity is O(num), i.e., no conflict.
//
// Let k be num / population, the expected number of extra sampling
// steps is roughly k^2 / (1-k) * population, which means in the worst
// case scenario, the time complexity is O(population^2). In practice,
// we use 1/10 since std::unordered_set is pretty slow.
std::unordered_set<IdxType> selected;
while (selected.size() < num) {
selected.insert(RandInt(population));
}
std::copy(selected.begin(), selected.end(), out);
} }
std::copy(selected.begin(), selected.end(), out);
} else { } else {
// reservoir algorithm // reservoir algorithm
// time: O(population), space: O(num) // time: O(population), space: O(num)
for (IdxType i = 0; i < num; ++i) for (IdxType i = 0; i < num; ++i) out[i] = i;
out[i] = i;
for (IdxType i = num; i < population; ++i) { for (IdxType i = num; i < population; ++i) {
const IdxType j = RandInt(i + 1); const IdxType j = RandInt(i + 1);
if (j < num) if (j < num) out[j] = i;
out[j] = i;
} }
} }
} }
} }
template void RandomEngine::UniformChoice<int32_t>( template void RandomEngine::UniformChoice<int32_t>(int32_t num,
int32_t num, int32_t population, int32_t* out, bool replace); int32_t population,
template void RandomEngine::UniformChoice<int64_t>( int32_t* out, bool replace);
int64_t num, int64_t population, int64_t* out, bool replace); template void RandomEngine::UniformChoice<int64_t>(int64_t num,
int64_t population,
int64_t* out, bool replace);
}; // namespace dgl }; // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment