/*!
 *  Copyright (c) 2019 by Contributors
 * \file random/choice.cc
 * \brief Non-uniform discrete sampling implementation
 */

#include <dgl/random.h>
#include <dgl/array.h>

#include <algorithm>
#include <numeric>
#include <unordered_set>
#include <vector>

#include "sample_utils.h"

namespace dgl {

/*!
 * \brief Draw a single index from the non-uniform distribution described by \p prob.
 *
 * The floating-point dtype of \p prob is dispatched at runtime through
 * ATEN_FLOAT_TYPE_SWITCH, and a single value is drawn from a
 * utils::TreeSampler built over \p prob.
 *
 * \tparam IdxType Integer type of the returned index (int32_t / int64_t are instantiated).
 * \param prob Per-candidate weights. Presumably non-negative and not necessarily
 *             normalized -- TODO confirm against TreeSampler's contract.
 * \return The sampled index.
 */
template<typename IdxType>
IdxType RandomEngine::Choice(FloatArray prob) {
  IdxType ret = 0;
  ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, "probability", {
    // TODO(minjie): allow choosing different sampling algorithms
    // The with-replacement sampler (third template argument `true`) is used:
    // only one draw is made, so no removal bookkeeping is required.
    utils::TreeSampler<IdxType, ValueType, true> sampler(this, prob);
    ret = sampler.Draw();
  });
  return ret;
}

template int32_t RandomEngine::Choice<int32_t>(FloatArray);
template int64_t RandomEngine::Choice<int64_t>(FloatArray);

template<typename IdxType, typename FloatType>
31
32
void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out, bool replace) {
  const IdxType N = prob->shape[0];
33
34
35
  if (!replace)
    CHECK_LE(num, N) << "Cannot take more sample than population when 'replace=false'";
  if (num == N && !replace)
36
    std::iota(out, out + num, 0);
37
38
39
40
41
42
43

  utils::BaseSampler<IdxType>* sampler = nullptr;
  if (replace) {
    sampler = new utils::TreeSampler<IdxType, FloatType, true>(this, prob);
  } else {
    sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob);
  }
44
45
  for (IdxType i = 0; i < num; ++i)
    out[i] = sampler->Draw();
46
47
48
  delete sampler;
}

49
50
51
52
53
54
55
56
template void RandomEngine::Choice<int32_t, float>(
    int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, float>(
    int64_t num, FloatArray prob, int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, double>(
    int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, double>(
    int64_t num, FloatArray prob, int64_t* out, bool replace);
57
58

/*!
 * \brief Draw \p num indices uniformly from [0, population).
 *
 * With replacement, each output is an independent RandInt(population) draw.
 * Without replacement, one of two strategies is used, depending on how large
 * \p num is relative to \p population:
 *  - a hash-set rejection loop for small \p num (output order is arbitrary);
 *  - reservoir sampling (Algorithm R) otherwise.
 *
 * \param num Number of indices to draw.
 * \param population Size of the index range to sample from.
 * \param out Output buffer that must hold at least \p num elements.
 * \param replace Whether to sample with replacement.
 */
template <typename IdxType>
void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out, bool replace) {
  if (!replace) {
    CHECK_LE(num, population) << "Cannot take more sample than population when 'replace=false'";
  }

  if (replace) {
    for (IdxType i = 0; i < num; ++i)
      out[i] = RandInt(population);
    return;
  }

  if (num < population / 10) {  // TODO(minjie): may need a better threshold here
    // use hash set
    // In the best scenario, time complexity is O(num), i.e., no conflict.
    //
    // Let k be num / population, the expected number of extra sampling steps is roughly
    // k^2 / (1-k) * population, which means in the worst case scenario,
    // the time complexity is O(population^2). In practice, we use 1/10 since
    // std::unordered_set is pretty slow.
    const size_t target = static_cast<size_t>(num);
    std::unordered_set<IdxType> selected;
    selected.reserve(target);
    while (selected.size() < target) {
      selected.insert(RandInt(population));
    }
    std::copy(selected.begin(), selected.end(), out);
  } else {
    // reservoir algorithm (Algorithm R)
    // time: O(population), space: O(num)
    std::iota(out, out + num, static_cast<IdxType>(0));
    for (IdxType i = num; i < population; ++i) {
      // RandInt(n) yields a value in [0, n); the RandInt(population) index
      // draws above rely on the same convention. Element i must replace a
      // reservoir slot with probability num / (i + 1), so j has to range over
      // [0, i] inclusive. RandInt(i) would exclude i, bias the sample, and
      // call RandInt(0) when num == 0.
      const IdxType j = RandInt(i + 1);
      if (j < num)
        out[j] = i;
    }
  }
}

template void RandomEngine::UniformChoice<int32_t>(
    int32_t num, int32_t population, int32_t* out, bool replace);
template void RandomEngine::UniformChoice<int64_t>(
    int64_t num, int64_t population, int64_t* out, bool replace);
};  // namespace dgl