choice.cc 4.51 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2019 by Contributors
 * \file random/choice.cc
 * \brief Non-uniform discrete sampling implementation
 */

7
#include <dgl/array.h>
8
#include <dgl/random.h>
9
#include <numeric>
10
#include <vector>
11
12
13
14
#include "sample_utils.h"

namespace dgl {

15
template <typename IdxType>
16
IdxType RandomEngine::Choice(FloatArray prob) {
17
  IdxType ret = 0;
18
  ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, "probability", {
19
    // TODO(minjie): allow choosing different sampling algorithms
20
21
22
23
24
25
26
27
28
    utils::TreeSampler<IdxType, ValueType, true> sampler(this, prob);
    ret = sampler.Draw();
  });
  return ret;
}

template int32_t RandomEngine::Choice<int32_t>(FloatArray);
template int64_t RandomEngine::Choice<int64_t>(FloatArray);

29
30
31
template <typename IdxType, typename FloatType>
void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out,
                          bool replace) {
32
  const IdxType N = prob->shape[0];
33
  if (!replace)
34
35
36
    CHECK_LE(num, N)
      << "Cannot take more sample than population when 'replace=false'";
  if (num == N && !replace) std::iota(out, out + num, 0);
37
38
39
40
41
42
43

  utils::BaseSampler<IdxType>* sampler = nullptr;
  if (replace) {
    sampler = new utils::TreeSampler<IdxType, FloatType, true>(this, prob);
  } else {
    sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob);
  }
44
  for (IdxType i = 0; i < num; ++i) out[i] = sampler->Draw();
45
46
47
  delete sampler;
}

48
49
50
51
52
53
54
55
56
57
template void RandomEngine::Choice<int32_t, float>(int32_t num, FloatArray prob,
                                                   int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, float>(int64_t num, FloatArray prob,
                                                   int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, double>(int32_t num,
                                                    FloatArray prob,
                                                    int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, double>(int64_t num,
                                                    FloatArray prob,
                                                    int64_t* out, bool replace);
58
59

/*!
 * \brief Draw \a num integers with RandInt(population) and store them in
 *        \a out, with or without replacement.
 *
 * With replacement, each output entry is an independent RandInt(population)
 * draw. Without replacement, \a num must not exceed \a population (enforced
 * with CHECK_LE) and one of three strategies is used:
 *   - num < population / 10 and 0 < num < 64: rejection sampling with a
 *     linear duplicate scan over the values accepted so far;
 *   - num < population / 10 otherwise: rejection sampling into a hash set;
 *   - else: reservoir sampling over 0 .. population-1.
 *
 * A non-positive \a num writes nothing.
 *
 * \tparam IdxType Integer type of the samples.
 * \param num Number of values to produce.
 * \param population Upper bound passed to RandInt.
 * \param out Output buffer; entries [0, num) are written.
 * \param replace Whether to sample with replacement.
 */
template <typename IdxType>
void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
                                 bool replace) {
  if (replace) {
    for (IdxType i = 0; i < num; ++i) out[i] = RandInt(population);
    return;
  }

  CHECK_LE(num, population)
    << "Cannot take more sample than population when 'replace=false'";

  if (num <
      population / 10) {  // TODO(minjie): may need a better threshold here
    // `num > 0` (rather than `num != 0`) keeps a negative count out of this
    // branch, where it would write out[0] and then never reach the end
    // pointer.
    if (num > 0 && num < 64) {
      // For a small sample (fewer than 64 values) a linear scan over the
      // values accepted so far is a cheaper uniqueness check than a hash set.
      IdxType* const end = out + num;
      IdxType* next = out;
      while (next != end) {
        // Tentatively place the candidate in the next free slot; keep it
        // only if it does not already appear in [out, next), otherwise
        // redraw into the same slot.
        *next = RandInt(population);
        if (std::find(out, next, *next) == next) ++next;
      }
    } else {
      // use hash set
      // In the best scenario, time complexity is O(num), i.e., no conflict.
      //
      // Let k be num / population, the expected number of extra sampling
      // steps is roughly k^2 / (1-k) * population, which means in the worst
      // case scenario, the time complexity is O(population^2). In practice,
      // we use 1/10 since std::unordered_set is pretty slow.
      std::unordered_set<IdxType> selected;
      // Compare as IdxType: a size_t comparison would turn a negative `num`
      // into a huge unsigned bound and loop forever.
      while (static_cast<IdxType>(selected.size()) < num) {
        selected.insert(RandInt(population));
      }
      std::copy(selected.begin(), selected.end(), out);
    }
  } else {
    // reservoir algorithm
    // time: O(population), space: O(num)
    for (IdxType i = 0; i < num; ++i) out[i] = i;
    for (IdxType i = num; i < population; ++i) {
      // Candidate i replaces reservoir slot j only when j falls inside the
      // reservoir.
      const IdxType j = RandInt(i + 1);
      if (j < num) out[j] = i;
    }
  }
}

// Explicit instantiations for the 32-bit and 64-bit index types.
template void RandomEngine::UniformChoice<int32_t>(int32_t num,
                                                   int32_t population,
                                                   int32_t* out, bool replace);
template void RandomEngine::UniformChoice<int64_t>(int64_t num,
                                                   int64_t population,
                                                   int64_t* out, bool replace);
120

121
};  // namespace dgl