/**
 *  Copyright (c) 2019 by Contributors
 * @file dgl/sample_utils.h
 * @brief Sampling utilities
 */
#ifndef DGL_RANDOM_CPU_SAMPLE_UTILS_H_
#define DGL_RANDOM_CPU_SAMPLE_UTILS_H_

#include <dgl/array.h>
#include <dgl/random.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <queue>
#include <utility>
#include <vector>

namespace dgl {
namespace utils {
23

24
/** @brief Base sampler class */
25
template <typename Idx>
26
27
class BaseSampler {
 public:
28
  virtual ~BaseSampler() = default;
29
  /** @brief Draw one integer sample */
30
  virtual Idx Draw() {
31
32
33
34
35
    LOG(INFO) << "Not implemented yet.";
    return 0;
  }
};

// (BarclayII 2022.9.20) The internal data type of probabilities was changed
// to double because non-uniform sampling is also used to sample on boolean
// masks, where False represents probability 0. DType could be uint8 in that
// case, which would give incorrect arithmetic results due to overflow and/or
// integer division.
/**
43
44
45
46
47
 * AliasSampler is used to sample elements from a given discrete categorical
 * distribution. Algorithm: Alias
 * Method(https://en.wikipedia.org/wiki/Alias_method) Sampler building
 * complexity: O(n) Sample w/ replacement complexity: O(1) Sample w/o
 * replacement complexity: O(log n)
48
 */
49
50
template <typename Idx, typename DType, bool replace>
class AliasSampler : public BaseSampler<Idx> {
51
52
53
 private:
  RandomEngine *re;
  Idx N;
54
55
56
57
58
59
60
  double accum, taken;    // accumulated likelihood
  std::vector<Idx> K;     // alias table
  std::vector<double> U;  // probability table
  FloatArray _prob;       // category distribution
  std::vector<bool>
      used;  // indicate availability, activated when replace=false;
  std::vector<Idx> id_mapping;  // index mapping, activated when replace=false;
61

62
  inline Idx Map(Idx x) const {  // Map consecutive indices to unused elements
63
64
65
66
67
68
    if (replace)
      return x;
    else
      return id_mapping[x];
  }

69
70
  void Reconstruct(FloatArray prob) {  // Reconstruct alias table
    const int64_t prob_size = prob->shape[0];
71
    const DType *prob_data = prob.Ptr<DType>();
72
73
74
    N = 0;
    accum = 0.;
    taken = 0.;
75
    if (!replace) id_mapping.clear();
76
    for (Idx i = 0; i < prob_size; ++i)
77
78
      if (!used[i]) {
        N++;
79
        accum += prob_data[i];
80
        if (!replace) id_mapping.push_back(i);
81
      }
82
83
84
    if (N == 0)
      LOG(FATAL)
          << "Cannot take more sample than population when 'replace=false'";
85
86
    K.resize(N);
    U.resize(N);
87
    double avg = accum / static_cast<double>(N);
88
    std::fill(U.begin(), U.end(), avg);  // initialize U
89
    std::queue<std::pair<Idx, double> > under, over;
90
    for (Idx i = 0; i < N; ++i) {
91
      double p = prob_data[Map(i)];
92
93
94
95
      if (p > avg)
        over.push(std::make_pair(i, p));
      else
        under.push(std::make_pair(i, p));
96
      K[i] = i;  // initialize K
97
98
99
100
    }
    while (!under.empty() && !over.empty()) {
      auto u_pair = under.front(), o_pair = over.front();
      Idx i_u = u_pair.first, i_o = o_pair.first;
101
      double p_u = u_pair.second, p_o = o_pair.second;
102
103
104
105
106
107
108
109
110
111
112
113
      K[i_u] = i_o;
      U[i_u] = p_u;
      if (p_o + p_u > 2 * avg)
        over.push(std::make_pair(i_o, p_o + p_u - avg));
      else if (p_o + p_u < 2 * avg)
        under.push(std::make_pair(i_o, p_o + p_u - avg));
      under.pop();
      over.pop();
    }
  }

 public:
114
115
  void ResetState(FloatArray prob) {
    used.resize(prob->shape[0]);
116
    if (!replace) _prob = prob;
117
    std::fill(used.begin(), used.end(), false);
118
    Reconstruct(prob);
119
120
  }

121
  explicit AliasSampler(RandomEngine *re, FloatArray prob) : re(re) {
122
    ResetState(prob);
123
124
125
126
  }

  ~AliasSampler() {}

127
  Idx Draw() {
128
    if (!replace) {
129
      const DType *_prob_data = _prob.Ptr<DType>();
130
131
132
133
      if (2 * taken >= accum) Reconstruct(_prob);
      if (accum <= 0) return -1;
      // accum changes after Reconstruct(), so avg should be computed after
      // that.
134
      double avg = accum / N;
135
      while (true) {
136
        double dice = re->Uniform<double>(0, N);
137
        Idx i = static_cast<Idx>(dice), rst;
138
139
        double p = (dice - i) * avg;
        if (p <= U[i]) {
140
          rst = Map(i);
141
        } else {
142
          rst = Map(K[i]);
143
        }
144
        double cap = _prob_data[rst];
145
146
147
148
149
150
151
        if (!used[rst]) {
          used[rst] = true;
          taken += cap;
          return rst;
        }
      }
    }
152
    if (accum <= 0) return -1;
153
154
    double avg = accum / N;
    double dice = re->Uniform<double>(0, N);
155
    Idx i = static_cast<Idx>(dice);
156
157
    double p = (dice - i) * avg;
    if (p <= U[i])
158
      return Map(i);
159
    else
160
      return Map(K[i]);
161
162
163
  }
};

164
/**
165
166
167
168
 * CDFSampler is used to sample elements from a given discrete categorical
 * distribution. Algorithm: create a cumulative distribution function and
 * conduct binary search for sampling. Reference:
 * https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804
169
170
171
 * Sampler building complexity: O(n)
 * Sample w/ and w/o replacement complexity: O(log n)
 */
172
173
template <typename Idx, typename DType, bool replace>
class CDFSampler : public BaseSampler<Idx> {
174
175
176
 private:
  RandomEngine *re;
  Idx N;
177
  double accum, taken;
178
179
180
181
182
183
  FloatArray _prob;         // categorical distribution
  std::vector<double> cdf;  // cumulative distribution function
  std::vector<bool>
      used;  // indicate availability, activated when replace=false;
  std::vector<Idx>
      id_mapping;  // indicate index mapping, activated when replace=false;
184

185
  inline Idx Map(Idx x) const {  // Map consecutive indices to unused elements
186
187
188
189
190
191
    if (replace)
      return x;
    else
      return id_mapping[x];
  }

192
193
  void Reconstruct(FloatArray prob) {  // Reconstruct CDF
    int64_t prob_size = prob->shape[0];
194
    const DType *prob_data = prob.Ptr<DType>();
195
196
197
    N = 0;
    accum = 0.;
    taken = 0.;
198
    if (!replace) id_mapping.clear();
199
200
    cdf.clear();
    cdf.push_back(0);
201
    for (Idx i = 0; i < prob_size; ++i)
202
203
      if (!used[i]) {
        N++;
204
        accum += prob_data[i];
205
        if (!replace) id_mapping.push_back(i);
206
207
        cdf.push_back(accum);
      }
208
209
210
    if (N == 0)
      LOG(FATAL)
          << "Cannot take more sample than population when 'replace=false'";
211
212
213
  }

 public:
214
215
  void ResetState(FloatArray prob) {
    used.resize(prob->shape[0]);
216
    if (!replace) _prob = prob;
217
    std::fill(used.begin(), used.end(), false);
218
    Reconstruct(prob);
219
220
  }

221
  explicit CDFSampler(RandomEngine *re, FloatArray prob) : re(re) {
222
    ResetState(prob);
223
224
225
226
  }

  ~CDFSampler() {}

227
  Idx Draw() {
228
    double eps = std::numeric_limits<double>::min();
229
    if (!replace) {
230
      const DType *_prob_data = _prob.Ptr<DType>();
231
232
      if (2 * taken >= accum) Reconstruct(_prob);
      if (accum <= 0) return -1;
233
      while (true) {
234
        double p = std::max(re->Uniform<double>(0., accum), eps);
235
236
        Idx rst =
            Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
237
        double cap = static_cast<double>(_prob_data[rst]);
238
239
240
241
242
243
244
        if (!used[rst]) {
          used[rst] = true;
          taken += cap;
          return rst;
        }
      }
    }
245
    if (accum <= 0) return -1;
246
    double p = std::max(re->Uniform<double>(0., accum), eps);
247
    return Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
248
249
250
  }
};

251
/**
252
253
254
 * TreeSampler is used to sample elements from a given discrete categorical
 * distribution. Algorithm: create a heap that stores accumulated likelihood of
 * its leaf descendents. Reference: https://blog.smola.org/post/1016514759
255
256
257
 * Sampler building complexity: O(n)
 * Sample w/ and w/o replacement complexity: O(log n)
 */
258
259
template <typename Idx, typename DType, bool replace>
class TreeSampler : public BaseSampler<Idx> {
260
261
 private:
  RandomEngine *re;
262
  std::vector<double> weight;  // accumulated likelihood of subtrees.
263
264
  int64_t N;
  int64_t num_leafs;
265
  const DType *decrease;
266
267

 public:
268
269
  void ResetState(FloatArray prob) {
    int64_t prob_size = prob->shape[0];
270
    const DType *prob_data = prob.Ptr<DType>();
271
    std::fill(weight.begin(), weight.end(), 0);
272
273
274
    for (int64_t i = 0; i < prob_size; ++i)
      weight[num_leafs + i] = prob_data[i];
    for (int64_t i = num_leafs - 1; i >= 1; --i)
275
276
277
      weight[i] = weight[i * 2] + weight[i * 2 + 1];
  }

278
279
280
  explicit TreeSampler(
      RandomEngine *re, FloatArray prob, const DType *decrease = nullptr)
      : re(re), decrease(decrease) {
281
    num_leafs = 1;
282
    while (num_leafs < prob->shape[0]) num_leafs *= 2;
283
284
    N = num_leafs * 2;
    weight.resize(N);
285
    ResetState(prob);
286
287
  }

288
289
  /* Pick an element from the given distribution and update the tree.
   *
290
291
292
293
294
   * The parameter decrease is an array of which the length is the number of
   * categories. Every time an element in the category x is picked, the weight
   * of this category is subtracted by decrease[x]. It is used to support the
   * case where a category might contains multiple candidates and decrease[x] is
   * the weight of one candidate of the category x.
295
   *
296
297
   * When decrease == nullptr, it means there is only one candidate in each
   * category and will directly set the weight of the chosen category as 0.
298
299
   *
   */
300
  Idx Draw() {
301
    if (weight[1] <= 0) return -1;
302
    int64_t cur = 1;
303
304
    double p = re->Uniform<double>(0, weight[cur]);
    double accum = 0.;
305
    while (cur < num_leafs) {
306
307
308
      double w_l = weight[cur * 2], w_r = weight[cur * 2 + 1];
      double pivot = accum + w_l;
      // w_r > 0 can suppress some numerical problems.
309
310
      Idx shift = static_cast<Idx>(p > pivot && w_r > 0);
      cur = cur * 2 + shift;
311
      if (shift == 1) accum = pivot;
312
313
314
315
316
    }
    Idx rst = cur - num_leafs;
    if (!replace) {
      while (cur >= 1) {
        if (cur >= num_leafs)
317
318
319
320
          weight[cur] =
              this->decrease
                  ? weight[cur] - static_cast<double>(this->decrease[rst])
                  : 0.;
321
322
323
324
325
326
327
328
329
        else
          weight[cur] = weight[cur * 2] + weight[cur * 2 + 1];
        cur /= 2;
      }
    }
    return rst;
  }
};

330
};  // namespace utils
};  // namespace dgl

#endif  // DGL_RANDOM_CPU_SAMPLE_UTILS_H_