/*!
 *  Copyright (c) 2019 by Contributors
 * \file dgl/sample_utils.h
 * \brief Sampling utilities
 */
#ifndef DGL_RANDOM_CPU_SAMPLE_UTILS_H_
#define DGL_RANDOM_CPU_SAMPLE_UTILS_H_
#include <dgl/random.h>
#include <dgl/array.h>
#include <algorithm>
#include <utility>
#include <queue>
#include <cstdlib>
#include <cmath>
#include <numeric>
#include <limits>
#include <vector>

namespace dgl {
namespace utils {
/*! \brief Base sampler class */
template <typename Idx>
25
26
class BaseSampler {
 public:
27
28
  virtual ~BaseSampler() = default;
  /*! \brief Draw one integer sample */
29
  virtual Idx Draw() {
30
31
32
33
34
    LOG(INFO) << "Not implemented yet.";
    return 0;
  }
};

// (BarclayII 2022.9.20) The samplers below accumulate probabilities internally as double
// rather than DType, because non-uniform sampling is also used to sample on boolean masks,
// where False represents probability 0.  DType could be uint8 in that case, which would give
// incorrect arithmetic results due to overflow and/or integer division.

/*
 * AliasSampler is used to sample elements from a given discrete categorical distribution.
 * Algorithm: Alias Method(https://en.wikipedia.org/wiki/Alias_method)
 * Sampler building complexity: O(n)
 * Sample w/ replacement complexity: O(1)
 * Sample w/o replacement complexity: O(log n)
 */
template <
  typename Idx,
  typename DType,
  bool replace>
51
class AliasSampler: public BaseSampler<Idx> {
52
53
54
 private:
  RandomEngine *re;
  Idx N;
55
  double accum, taken;            // accumulated likelihood
56
  std::vector<Idx> K;             // alias table
57
  std::vector<double> U;          // probability table
58
  FloatArray _prob;               // category distribution
59
60
61
  std::vector<bool> used;         // indicate availability, activated when replace=false;
  std::vector<Idx> id_mapping;    // index mapping, activated when replace=false;

62
  inline Idx Map(Idx x) const {   // Map consecutive indices to unused elements
63
64
65
66
67
68
    if (replace)
      return x;
    else
      return id_mapping[x];
  }

69
70
  void Reconstruct(FloatArray prob) {  // Reconstruct alias table
    const int64_t prob_size = prob->shape[0];
71
    const DType *prob_data = prob.Ptr<DType>();
72
73
74
75
76
    N = 0;
    accum = 0.;
    taken = 0.;
    if (!replace)
      id_mapping.clear();
77
    for (Idx i = 0; i < prob_size; ++i)
78
79
      if (!used[i]) {
        N++;
80
        accum += prob_data[i];
81
82
83
84
85
86
        if (!replace)
          id_mapping.push_back(i);
      }
    if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'";
    K.resize(N);
    U.resize(N);
87
    double avg = accum / static_cast<double>(N);
88
    std::fill(U.begin(), U.end(), avg);     // initialize U
89
    std::queue<std::pair<Idx, double> > under, over;
90
    for (Idx i = 0; i < N; ++i) {
91
      double p = prob_data[Map(i)];
92
93
94
95
96
97
98
99
100
      if (p > avg)
        over.push(std::make_pair(i, p));
      else
        under.push(std::make_pair(i, p));
      K[i] = i;                             // initialize K
    }
    while (!under.empty() && !over.empty()) {
      auto u_pair = under.front(), o_pair = over.front();
      Idx i_u = u_pair.first, i_o = o_pair.first;
101
      double p_u = u_pair.second, p_o = o_pair.second;
102
103
104
105
106
107
108
109
110
111
112
113
      K[i_u] = i_o;
      U[i_u] = p_u;
      if (p_o + p_u > 2 * avg)
        over.push(std::make_pair(i_o, p_o + p_u - avg));
      else if (p_o + p_u < 2 * avg)
        under.push(std::make_pair(i_o, p_o + p_u - avg));
      under.pop();
      over.pop();
    }
  }

 public:
114
115
  void ResetState(FloatArray prob) {
    used.resize(prob->shape[0]);
116
117
118
    if (!replace)
      _prob = prob;
    std::fill(used.begin(), used.end(), false);
119
    Reconstruct(prob);
120
121
  }

122
  explicit AliasSampler(RandomEngine* re, FloatArray prob): re(re) {
123
    ResetState(prob);
124
125
126
127
  }

  ~AliasSampler() {}

128
  Idx Draw() {
129
    if (!replace) {
130
      const DType *_prob_data = _prob.Ptr<DType>();
131
      if (2 * taken >= accum)
132
        Reconstruct(_prob);
133
134
135
136
      if (accum <= 0)
        return -1;
      // accum changes after Reconstruct(), so avg should be computed after that.
      double avg = accum / N;
137
      while (true) {
138
        double dice = re->Uniform<double>(0, N);
139
        Idx i = static_cast<Idx>(dice), rst;
140
141
        double p = (dice - i) * avg;
        if (p <= U[i]) {
142
          rst = Map(i);
143
        } else {
144
          rst = Map(K[i]);
145
        }
146
        double cap = _prob_data[rst];
147
148
149
150
151
152
153
        if (!used[rst]) {
          used[rst] = true;
          taken += cap;
          return rst;
        }
      }
    }
154
155
156
157
    if (accum <= 0)
      return -1;
    double avg = accum / N;
    double dice = re->Uniform<double>(0, N);
158
    Idx i = static_cast<Idx>(dice);
159
160
    double p = (dice - i) * avg;
    if (p <= U[i])
161
        return Map(i);
162
    else
163
        return Map(K[i]);
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
  }
};


/*
 * CDFSampler is used to sample elements from a given discrete categorical distribution.
 * Algorithm: create a cumulative distribution function and conduct binary search for sampling.
 * Reference: https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804
 * Sampler building complexity: O(n)
 * Sample w/ and w/o replacement complexity: O(log n)
 */
template <
  typename Idx,
  typename DType,
  bool replace>
179
class CDFSampler: public BaseSampler<Idx> {
180
181
182
 private:
  RandomEngine *re;
  Idx N;
183
  double accum, taken;
184
  FloatArray _prob;             // categorical distribution
185
  std::vector<double> cdf;      // cumulative distribution function
186
187
188
  std::vector<bool> used;       // indicate availability, activated when replace=false;
  std::vector<Idx> id_mapping;  // indicate index mapping, activated when replace=false;

189
  inline Idx Map(Idx x) const {   // Map consecutive indices to unused elements
190
191
192
193
194
195
    if (replace)
      return x;
    else
      return id_mapping[x];
  }

196
197
  void Reconstruct(FloatArray prob) {  // Reconstruct CDF
    int64_t prob_size = prob->shape[0];
198
    const DType *prob_data = prob.Ptr<DType>();
199
200
201
202
203
204
205
    N = 0;
    accum = 0.;
    taken = 0.;
    if (!replace)
      id_mapping.clear();
    cdf.clear();
    cdf.push_back(0);
206
    for (Idx i = 0; i < prob_size; ++i)
207
208
      if (!used[i]) {
        N++;
209
        accum += prob_data[i];
210
211
212
213
214
215
216
217
        if (!replace)
          id_mapping.push_back(i);
        cdf.push_back(accum);
      }
    if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'";
  }

 public:
218
219
  void ResetState(FloatArray prob) {
    used.resize(prob->shape[0]);
220
221
222
    if (!replace)
      _prob = prob;
    std::fill(used.begin(), used.end(), false);
223
    Reconstruct(prob);
224
225
  }

226
  explicit CDFSampler(RandomEngine *re, FloatArray prob): re(re) {
227
    ResetState(prob);
228
229
230
231
  }

  ~CDFSampler() {}

232
  Idx Draw() {
233
    double eps = std::numeric_limits<double>::min();
234
    if (!replace) {
235
      const DType *_prob_data = _prob.Ptr<DType>();
236
      if (2 * taken >= accum)
237
        Reconstruct(_prob);
238
239
      if (accum <= 0)
        return -1;
240
      while (true) {
241
        double p = std::max(re->Uniform<double>(0., accum), eps);
242
        Idx rst = Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
243
        double cap = static_cast<double>(_prob_data[rst]);
244
245
246
247
248
249
250
        if (!used[rst]) {
          used[rst] = true;
          taken += cap;
          return rst;
        }
      }
    }
251
252
253
    if (accum <= 0)
      return -1;
    double p = std::max(re->Uniform<double>(0., accum), eps);
254
    return Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
  }
};


/*
 * TreeSampler is used to sample elements from a given discrete categorical distribution.
 * Algorithm: create a heap that stores accumulated likelihood of its leaf descendents.
 * Reference: https://blog.smola.org/post/1016514759
 * Sampler building complexity: O(n)
 * Sample w/ and w/o replacement complexity: O(log n)
 */
template <
  typename Idx,
  typename DType,
  bool replace>
270
class TreeSampler: public BaseSampler<Idx> {
271
272
 private:
  RandomEngine *re;
273
  std::vector<double> weight;    // accumulated likelihood of subtrees.
274
275
  int64_t N;
  int64_t num_leafs;
276
  const DType *decrease;
277
278

 public:
279
280
  void ResetState(FloatArray prob) {
    int64_t prob_size = prob->shape[0];
281
    const DType *prob_data = prob.Ptr<DType>();
282
    std::fill(weight.begin(), weight.end(), 0);
283
284
285
    for (int64_t i = 0; i < prob_size; ++i)
      weight[num_leafs + i] = prob_data[i];
    for (int64_t i = num_leafs - 1; i >= 1; --i)
286
287
288
      weight[i] = weight[i * 2] + weight[i * 2 + 1];
  }

289
290
  explicit TreeSampler(RandomEngine *re, FloatArray prob, const DType* decrease = nullptr)
    : re(re), decrease(decrease) {
291
    num_leafs = 1;
292
    while (num_leafs < prob->shape[0])
293
294
295
      num_leafs *= 2;
    N = num_leafs * 2;
    weight.resize(N);
296
    ResetState(prob);
297
298
  }

299
300
301
302
303
304
305
306
307
308
309
  /* Pick an element from the given distribution and update the tree.
   *
   * The parameter decrease is an array of which the length is the number of categories.
   * Every time an element in the category x is picked, the weight of this category is subtracted
   * by decrease[x]. It is used to support the case where a category might contains multiple
   * candidates and decrease[x] is the weight of one candidate of the category x.
   *
   * When decrease == nullptr, it means there is only one candidate in each category and will
   * directly set the weight of the chosen category as 0.
   *
   */
310
  Idx Draw() {
311
312
    if (weight[1] <= 0)
      return -1;
313
    int64_t cur = 1;
314
315
    double p = re->Uniform<double>(0, weight[cur]);
    double accum = 0.;
316
    while (cur < num_leafs) {
317
318
319
      double w_l = weight[cur * 2], w_r = weight[cur * 2 + 1];
      double pivot = accum + w_l;
      // w_r > 0 can suppress some numerical problems.
320
321
322
323
324
325
326
327
328
      Idx shift = static_cast<Idx>(p > pivot && w_r > 0);
      cur = cur * 2 + shift;
      if (shift == 1)
        accum = pivot;
    }
    Idx rst = cur - num_leafs;
    if (!replace) {
      while (cur >= 1) {
        if (cur >= num_leafs)
329
330
          weight[cur] = this->decrease ?
            weight[cur] - static_cast<double>(this->decrease[rst]) : 0.;
331
332
333
334
335
336
337
338
339
        else
          weight[cur] = weight[cur * 2] + weight[cur * 2 + 1];
        cur /= 2;
      }
    }
    return rst;
  }
};

};  // namespace utils
};  // namespace dgl

#endif  // DGL_RANDOM_CPU_SAMPLE_UTILS_H_