choice.cc 8.45 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2019 by Contributors
 * \file random/choice.cc
 * \brief Discrete sampling implementations (non-uniform, uniform and tag-biased choice)
 */

7
#include <dgl/array.h>
#include <dgl/random.h>

#include <algorithm>
#include <numeric>
#include <unordered_set>
#include <utility>
#include <vector>

#include "sample_utils.h"

namespace dgl {

15
template <typename IdxType>
16
IdxType RandomEngine::Choice(FloatArray prob) {
17
  IdxType ret = 0;
18
  ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, "probability", {
19
    // TODO(minjie): allow choosing different sampling algorithms
20
21
22
23
24
25
26
27
28
    utils::TreeSampler<IdxType, ValueType, true> sampler(this, prob);
    ret = sampler.Draw();
  });
  return ret;
}

// Explicit instantiations of the single-draw Choice for 32- and 64-bit index
// types; the template is defined in this file rather than in a header.
template int32_t RandomEngine::Choice<int32_t>(FloatArray);
template int64_t RandomEngine::Choice<int64_t>(FloatArray);

29
30
31
template <typename IdxType, typename FloatType>
void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out,
                          bool replace) {
32
  const IdxType N = prob->shape[0];
33
  if (!replace)
34
35
36
    CHECK_LE(num, N)
      << "Cannot take more sample than population when 'replace=false'";
  if (num == N && !replace) std::iota(out, out + num, 0);
37
38
39
40
41
42
43

  utils::BaseSampler<IdxType>* sampler = nullptr;
  if (replace) {
    sampler = new utils::TreeSampler<IdxType, FloatType, true>(this, prob);
  } else {
    sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob);
  }
44
  for (IdxType i = 0; i < num; ++i) out[i] = sampler->Draw();
45
46
47
  delete sampler;
}

48
49
50
51
52
53
54
55
56
57
// Explicit instantiations of the multi-draw Choice for every combination of
// {int32_t, int64_t} index type and {float, double} weight type.
template void RandomEngine::Choice<int32_t, float>(int32_t num, FloatArray prob,
                                                   int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, float>(int64_t num, FloatArray prob,
                                                   int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, double>(int32_t num,
                                                    FloatArray prob,
                                                    int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, double>(int64_t num,
                                                    FloatArray prob,
                                                    int64_t* out, bool replace);
58
59
60
61
62
63
64
65
66
67
// Explicit instantiations of the multi-draw Choice where the weight type
// (second template argument) is an 8-bit integer: int8_t or uint8_t.
template void RandomEngine::Choice<int32_t, int8_t>(int32_t num, FloatArray prob,
                                                    int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, int8_t>(int64_t num, FloatArray prob,
                                                    int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, uint8_t>(int32_t num,
                                                     FloatArray prob,
                                                     int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, uint8_t>(int64_t num,
                                                     FloatArray prob,
                                                     int64_t* out, bool replace);
68
69

template <typename IdxType>
70
71
void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
                                 bool replace) {
72
73
  CHECK_GE(num, 0) << "The numbers to sample should be non-negative.";
  CHECK_GE(population, 0) << "The population size should be non-negative.";
74
  if (!replace)
75
76
    CHECK_LE(num, population)
      << "Cannot take more sample than population when 'replace=false'";
77
  if (replace) {
78
    for (IdxType i = 0; i < num; ++i) out[i] = RandInt(population);
79
  } else {
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    if (num <
        population / 10) {  // TODO(minjie): may need a better threshold here
      // if set of numbers is small (up to 128) use linear search to verify
      // uniqueness this operation is cheaper for CPU.
      if (num && num < 64) {
        *out = RandInt(population);
        auto b = out + 1;
        auto e = b + num - 1;
        while (b != e) {
          // put the new value at the end
          *b = RandInt(population);
          // Check if a new value doesn't exist in current range(out,b)
          // otherwise get a new value until we haven't unique range of
          // elements.
          auto it = std::find(out, b, *b);
          if (it != b) continue;
          ++b;
        }

      } else {
        // use hash set
        // In the best scenario, time complexity is O(num), i.e., no conflict.
        //
        // Let k be num / population, the expected number of extra sampling
        // steps is roughly k^2 / (1-k) * population, which means in the worst
        // case scenario, the time complexity is O(population^2). In practice,
        // we use 1/10 since std::unordered_set is pretty slow.
        std::unordered_set<IdxType> selected;
108
        while (static_cast<IdxType>(selected.size()) < num) {
109
110
111
          selected.insert(RandInt(population));
        }
        std::copy(selected.begin(), selected.end(), out);
112
      }
113

114
    } else {
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
      // In this case, `num >= population / 10`. To reduce the computation overhead,
      // we should reduce the number of random number generations. Even though
      // reservior algorithm is more memory effficient (it has O(num) memory complexity),
      // it generates O(population) random numbers, which is computationally expensive.
      // This algorithm has memory complexity of O(population) but generates much fewer random
      // numbers O(num). In the case of `num >= population/10`, we don't need to worry about
      // memory complexity because `num` is usually small. So is `population`. Allocating a small
      // piece of memory is very efficient.
      std::vector<IdxType> seq(population);
      for (size_t i = 0; i < seq.size(); i++) seq[i] = i;
      for (IdxType i = 0; i < num; i++) {
        IdxType j = RandInt(i, population);
        std::swap(seq[i], seq[j]);
      }
      // Save the randomly sampled numbers.
      for (IdxType i = 0; i < num; i++) {
        out[i] = seq[i];
132
      }
133
134
135
136
    }
  }
}

137
138
139
140
141
142
// Explicit instantiations of UniformChoice for 32- and 64-bit index types;
// the template is defined in this file rather than in a header.
template void RandomEngine::UniformChoice<int32_t>(int32_t num,
                                                   int32_t population,
                                                   int32_t* out, bool replace);
template void RandomEngine::UniformChoice<int64_t>(int64_t num,
                                                   int64_t population,
                                                   int64_t* out, bool replace);
143

144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
/*!
 * \brief Sample `num` positions from a range partitioned into tags, where
 *        each tag carries its own bias.
 *
 * Tag `t` covers positions [split[t], split[t+1]) and is drawn with weight
 * (split[t+1] - split[t]) * bias[t]. Once a tag is drawn, a position inside
 * it is picked with RandInt and offset by split[t].
 *
 * \param num Number of positions to write to `out`.
 * \param split Tag boundaries; entries 0..bias->shape[0] are read.
 * \param bias Per-tag bias values; its length defines the number of tags.
 * \param out Output buffer; must have room for `num` entries.
 * \param replace If false, no position is returned twice and a check fails
 *        when `num` exceeds the total number of positions.
 */
template <typename IdxType, typename FloatType>
void RandomEngine::BiasedChoice(
    IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace) {
  const int64_t num_tags = bias->shape[0];
  const FloatType *bias_data = static_cast<FloatType *>(bias->data);
  IdxType total_node_num = 0;
  // Per-tag sampling weight: (number of positions in the tag) * (tag bias).
  FloatArray prob = NDArray::Empty({num_tags}, bias->dtype, bias->ctx);
  FloatType *prob_data = static_cast<FloatType *>(prob->data);
  for (int64_t tag = 0 ; tag < num_tags; ++tag) {
    int64_t tag_num_nodes = split[tag+1] - split[tag];
    total_node_num += tag_num_nodes;
    FloatType tag_bias = bias_data[tag];
    prob_data[tag] = tag_num_nodes * tag_bias;
  }
  if (replace) {
    auto sampler = utils::TreeSampler<IdxType, FloatType, true>(this, prob);
    for (IdxType i = 0; i < num; ++i) {
      const int64_t tag = sampler.Draw();
      const IdxType tag_num_nodes = split[tag+1] - split[tag];
      out[i] = RandInt(tag_num_nodes) + split[tag];
    }
  } else {
    // NOTE(review): bias_data is forwarded as the sampler's third constructor
    // argument; presumably the amount a tag's weight shrinks by per draw --
    // confirm in sample_utils.h.
    utils::TreeSampler<int64_t, FloatType, false> sampler(this, prob, bias_data);
    CHECK_GE(total_node_num, num)
        << "Cannot take more sample than population when 'replace=false'";
    // we use hash set here. Maybe in the future we should support reservoir algorithm
    std::vector<std::unordered_set<IdxType>> selected(num_tags);
    for (IdxType i = 0 ; i < num ; ++i) {
      const int64_t tag = sampler.Draw();
      const IdxType tag_num_nodes = split[tag+1] - split[tag];
      auto &taken = selected[tag];
      // The tag must still have an unused position, otherwise the retry loop
      // below could never terminate. The set only grows on a successful
      // insert (which ends the loop), so checking once per draw is enough.
      // The size is cast to IdxType to avoid a signed/unsigned comparison.
      CHECK_LT(static_cast<IdxType>(taken.size()), tag_num_nodes)
          << "Cannot take more sample than population when 'replace=false'";
      IdxType selected_node;
      // Rejection sampling: redraw until an unused position in this tag is hit.
      do {
        selected_node = RandInt(tag_num_nodes);
      } while (!taken.insert(selected_node).second);
      out[i] = selected_node + split[tag];
    }
  }
}

// Explicit instantiations of BiasedChoice for every combination of
// {int32_t, int64_t} index type and {float, double} bias type.
template void RandomEngine::BiasedChoice<int32_t, float>(
    int32_t, const int32_t*, FloatArray, int32_t*, bool);
template void RandomEngine::BiasedChoice<int32_t, double>(
    int32_t, const int32_t*, FloatArray, int32_t*, bool);
template void RandomEngine::BiasedChoice<int64_t, float>(
    int64_t, const int64_t*, FloatArray, int64_t*, bool);
template void RandomEngine::BiasedChoice<int64_t, double>(
    int64_t, const int64_t*, FloatArray, int64_t*, bool);

196
};  // namespace dgl