/**
 *  Copyright (c) 2019 by Contributors
 * @file random/choice.cc
 * @brief Non-uniform discrete sampling implementation
 */

7
#include <dgl/array.h>
8
#include <dgl/random.h>
9

10
#include <numeric>
11
#include <vector>
12

13
14
15
16
#include "sample_utils.h"

namespace dgl {

17
template <typename IdxType>
18
IdxType RandomEngine::Choice(FloatArray prob) {
19
  IdxType ret = 0;
20
  ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, "probability", {
21
    // TODO(minjie): allow choosing different sampling algorithms
22
23
24
25
26
27
28
29
30
    utils::TreeSampler<IdxType, ValueType, true> sampler(this, prob);
    ret = sampler.Draw();
  });
  return ret;
}

// Explicit instantiations of the single-draw Choice for int32_t and int64_t
// index types.
template int32_t RandomEngine::Choice<int32_t>(FloatArray);
template int64_t RandomEngine::Choice<int64_t>(FloatArray);

31
template <typename IdxType, typename FloatType>
32
33
void RandomEngine::Choice(
    IdxType num, FloatArray prob, IdxType* out, bool replace) {
34
  const IdxType N = prob->shape[0];
35
  if (!replace)
36
    CHECK_LE(num, N)
37
        << "Cannot take more sample than population when 'replace=false'";
38
  if (num == N && !replace) std::iota(out, out + num, 0);
39
40
41
42
43
44
45

  utils::BaseSampler<IdxType>* sampler = nullptr;
  if (replace) {
    sampler = new utils::TreeSampler<IdxType, FloatType, true>(this, prob);
  } else {
    sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob);
  }
46
  for (IdxType i = 0; i < num; ++i) out[i] = sampler->Draw();
47
48
49
  delete sampler;
}

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Explicit instantiations of the multi-draw Choice for every combination of
// index type (int32_t, int64_t) and value type (float, double, int8_t,
// uint8_t).
template void RandomEngine::Choice<int32_t, float>(
    int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, float>(
    int64_t num, FloatArray prob, int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, double>(
    int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, double>(
    int64_t num, FloatArray prob, int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, int8_t>(
    int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, int8_t>(
    int64_t num, FloatArray prob, int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, uint8_t>(
    int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, uint8_t>(
    int64_t num, FloatArray prob, int64_t* out, bool replace);
66
67

template <typename IdxType>
68
69
void RandomEngine::UniformChoice(
    IdxType num, IdxType population, IdxType* out, bool replace) {
70
71
  CHECK_GE(num, 0) << "The numbers to sample should be non-negative.";
  CHECK_GE(population, 0) << "The population size should be non-negative.";
72
  if (!replace)
73
    CHECK_LE(num, population)
74
        << "Cannot take more sample than population when 'replace=false'";
75
  if (replace) {
76
    for (IdxType i = 0; i < num; ++i) out[i] = RandInt(population);
77
  } else {
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
    if (num <
        population / 10) {  // TODO(minjie): may need a better threshold here
      // if set of numbers is small (up to 128) use linear search to verify
      // uniqueness this operation is cheaper for CPU.
      if (num && num < 64) {
        *out = RandInt(population);
        auto b = out + 1;
        auto e = b + num - 1;
        while (b != e) {
          // put the new value at the end
          *b = RandInt(population);
          // Check if a new value doesn't exist in current range(out,b)
          // otherwise get a new value until we haven't unique range of
          // elements.
          auto it = std::find(out, b, *b);
          if (it != b) continue;
          ++b;
        }

      } else {
        // use hash set
        // In the best scenario, time complexity is O(num), i.e., no conflict.
        //
        // Let k be num / population, the expected number of extra sampling
        // steps is roughly k^2 / (1-k) * population, which means in the worst
        // case scenario, the time complexity is O(population^2). In practice,
        // we use 1/10 since std::unordered_set is pretty slow.
        std::unordered_set<IdxType> selected;
106
        while (static_cast<IdxType>(selected.size()) < num) {
107
108
109
          selected.insert(RandInt(population));
        }
        std::copy(selected.begin(), selected.end(), out);
110
      }
111

112
    } else {
113
114
115
116
117
118
119
120
121
      // In this case, `num >= population / 10`. To reduce the computation
      // overhead, we should reduce the number of random number generations.
      // Even though reservior algorithm is more memory effficient (it has
      // O(num) memory complexity), it generates O(population) random numbers,
      // which is computationally expensive. This algorithm has memory
      // complexity of O(population) but generates much fewer random numbers
      // O(num). In the case of `num >= population/10`, we don't need to worry
      // about memory complexity because `num` is usually small. So is
      // `population`. Allocating a small piece of memory is very efficient.
122
123
124
125
126
127
128
129
130
      std::vector<IdxType> seq(population);
      for (size_t i = 0; i < seq.size(); i++) seq[i] = i;
      for (IdxType i = 0; i < num; i++) {
        IdxType j = RandInt(i, population);
        std::swap(seq[i], seq[j]);
      }
      // Save the randomly sampled numbers.
      for (IdxType i = 0; i < num; i++) {
        out[i] = seq[i];
131
      }
132
133
134
135
    }
  }
}

136
137
138
139
// Explicit instantiations of UniformChoice for int32_t and int64_t index
// types.
template void RandomEngine::UniformChoice<int32_t>(
    int32_t num, int32_t population, int32_t* out, bool replace);
template void RandomEngine::UniformChoice<int64_t>(
    int64_t num, int64_t population, int64_t* out, bool replace);
140

141
142
template <typename IdxType, typename FloatType>
void RandomEngine::BiasedChoice(
143
144
    IdxType num, const IdxType* split, FloatArray bias, IdxType* out,
    bool replace) {
145
  const int64_t num_tags = bias->shape[0];
146
  const FloatType* bias_data = static_cast<FloatType*>(bias->data);
147
148
  IdxType total_node_num = 0;
  FloatArray prob = NDArray::Empty({num_tags}, bias->dtype, bias->ctx);
149
150
151
  FloatType* prob_data = static_cast<FloatType*>(prob->data);
  for (int64_t tag = 0; tag < num_tags; ++tag) {
    int64_t tag_num_nodes = split[tag + 1] - split[tag];
152
153
154
155
156
157
158
159
    total_node_num += tag_num_nodes;
    FloatType tag_bias = bias_data[tag];
    prob_data[tag] = tag_num_nodes * tag_bias;
  }
  if (replace) {
    auto sampler = utils::TreeSampler<IdxType, FloatType, true>(this, prob);
    for (IdxType i = 0; i < num; ++i) {
      const int64_t tag = sampler.Draw();
160
      const IdxType tag_num_nodes = split[tag + 1] - split[tag];
161
162
163
      out[i] = RandInt(tag_num_nodes) + split[tag];
    }
  } else {
164
165
    utils::TreeSampler<int64_t, FloatType, false> sampler(
        this, prob, bias_data);
166
167
    CHECK_GE(total_node_num, num)
        << "Cannot take more sample than population when 'replace=false'";
168
169
    // we use hash set here. Maybe in the future we should support reservoir
    // algorithm
170
    std::vector<std::unordered_set<IdxType>> selected(num_tags);
171
    for (IdxType i = 0; i < num; ++i) {
172
173
      const int64_t tag = sampler.Draw();
      bool inserted = false;
174
      const IdxType tag_num_nodes = split[tag + 1] - split[tag];
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
      IdxType selected_node;
      while (!inserted) {
        CHECK_LT(selected[tag].size(), tag_num_nodes)
            << "Cannot take more sample than population when 'replace=false'";
        selected_node = RandInt(tag_num_nodes);
        inserted = selected[tag].insert(selected_node).second;
      }
      out[i] = selected_node + split[tag];
    }
  }
}

// Explicit instantiations of BiasedChoice for int32_t/int64_t index types
// combined with float/double bias types.
template void RandomEngine::BiasedChoice<int32_t, float>(
    int32_t, const int32_t*, FloatArray, int32_t*, bool);
template void RandomEngine::BiasedChoice<int32_t, double>(
    int32_t, const int32_t*, FloatArray, int32_t*, bool);
template void RandomEngine::BiasedChoice<int64_t, float>(
    int64_t, const int64_t*, FloatArray, int64_t*, bool);
template void RandomEngine::BiasedChoice<int64_t, double>(
    int64_t, const int64_t*, FloatArray, int64_t*, bool);

196
};  // namespace dgl