/*!
 *  Copyright (c) 2019 by Contributors
 * \file dgl/sample_utils.h
 * \brief Sampling utilities
 */
#ifndef DGL_RANDOM_CPU_SAMPLE_UTILS_H_
#define DGL_RANDOM_CPU_SAMPLE_UTILS_H_
#include <dgl/random.h>
#include <dgl/array.h>
#include <algorithm>
#include <utility>
#include <queue>
#include <cstdlib>
#include <cmath>
#include <numeric>
#include <limits>
#include <vector>

namespace dgl {
namespace utils {

template <
  typename Idx,
  typename DType,
  bool replace>
class BaseSampler {
 public:
29
  virtual Idx Draw() {
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
    LOG(INFO) << "Not implemented yet.";
    return 0;
  }
};

/*
 * AliasSampler is used to sample elements from a given discrete categorical distribution.
 * Algorithm: Alias Method (https://en.wikipedia.org/wiki/Alias_method)
 * Sampler building complexity: O(n)
 * Sample w/ replacement complexity: O(1)
 * Sample w/o replacement: rejection sampling on the alias table; the table is rebuilt
 * (O(n)) over the not-yet-taken elements once at least half of the likelihood mass it
 * was built from has been taken.
 */
template <
  typename Idx,
  typename DType,
  bool replace>
class AliasSampler: public BaseSampler<Idx, DType, replace> {
 private:
  RandomEngine *re;
  Idx N;                          // number of entries in the current alias table
  DType accum, taken;             // accumulated likelihood of the table / of taken elements
  std::vector<Idx> K;             // alias table (compact indices)
  std::vector<DType> U;           // probability table (indexed by compact index)
  FloatArray _prob;               // category distribution, only stored when replace=false
  std::vector<bool> used;         // indicate availability, activated when replace=false;
  std::vector<Idx> id_mapping;    // index mapping, activated when replace=false;

  // Map consecutive (compact) indices to unused elements of the distribution.
  inline Idx Map(Idx x) const {
    if (replace)
      return x;
    else
      return id_mapping[x];
  }

  // Rebuild the alias table over all elements that are not marked as used.
  // Resets N, accum and taken, and (when replace=false) the compact index mapping.
  void Reconstruct(FloatArray prob) {
    const int64_t prob_size = prob->shape[0];
    const DType *prob_data = static_cast<DType *>(prob->data);
    N = 0;
    accum = 0.;
    taken = 0.;
    if (!replace)
      id_mapping.clear();
    for (Idx i = 0; i < prob_size; ++i) {
      if (used[i])
        continue;
      N++;
      accum += prob_data[i];
      if (!replace)
        id_mapping.push_back(i);
    }
    if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'";
    K.resize(N);
    U.resize(N);
    const DType avg = accum / static_cast<DType>(N);
    std::fill(U.begin(), U.end(), avg);     // a slot keeps its full share unless paired below
    std::queue<std::pair<Idx, DType> > under, over;
    for (Idx i = 0; i < N; ++i) {
      const DType p = prob_data[Map(i)];
      if (p > avg)
        over.push(std::make_pair(i, p));
      else
        under.push(std::make_pair(i, p));
      K[i] = i;                             // every slot is initially its own alias
    }
    // Pair an under-full slot with an over-full one; the remainder of the over-full
    // slot is re-queued unless it is exactly the average.
    while (!under.empty() && !over.empty()) {
      const std::pair<Idx, DType> u_pair = under.front();
      const std::pair<Idx, DType> o_pair = over.front();
      under.pop();
      over.pop();
      const Idx i_u = u_pair.first, i_o = o_pair.first;
      const DType p_u = u_pair.second, p_o = o_pair.second;
      K[i_u] = i_o;
      U[i_u] = p_u;
      if (p_o + p_u > 2 * avg)
        over.push(std::make_pair(i_o, p_o + p_u - avg));
      else if (p_o + p_u < 2 * avg)
        under.push(std::make_pair(i_o, p_o + p_u - avg));
    }
  }

  // Draw one compact index from the current alias table.
  // The average is computed here, from the current accum and N, so it always matches
  // the table that is being sampled (the table may have just been rebuilt).
  inline Idx DrawSlot() {
    const DType avg = accum / N;
    const DType dice = re->Uniform<DType>(0, N);
    // Guard against the draw landing on N itself (e.g. through rounding).
    const Idx i = std::min<Idx>(static_cast<Idx>(dice), N - 1);
    const DType p = (dice - i) * avg;
    // U and K are both indexed by the compact slot index, not by the mapped index.
    if (p <= U[i])
      return i;
    else
      return K[i];
  }

 public:
  /*!
   * \brief Forget all previously taken elements and rebuild the table from \a prob.
   * \param prob 1-D array of likelihoods; its first dimension gives the population size.
   */
  void ResetState(FloatArray prob) {
    used.resize(prob->shape[0]);
    if (!replace)
      _prob = prob;
    std::fill(used.begin(), used.end(), false);
    Reconstruct(prob);
  }

  /*!
   * \brief Build a sampler.
   * \param re   random engine used for every draw (not owned).
   * \param prob 1-D array of likelihoods.
   */
  explicit AliasSampler(RandomEngine* re, FloatArray prob): re(re) {
    ResetState(prob);
  }

  ~AliasSampler() {}

  /*!
   * \brief Draw one index of the distribution.
   *
   * When replace=false the returned element is marked as used and is never returned
   * again until ResetState() is called; drawing more elements than the population
   * triggers LOG(FATAL) in the table rebuild.
   */
  Idx Draw() override {
    if (!replace) {
      // _prob is only assigned when replace=false, so it is only touched here.
      const DType *prob_data = static_cast<DType *>(_prob->data);
      if (2 * taken >= accum)
        Reconstruct(_prob);
      while (true) {   // rejection sampling: retry until an unused element is hit
        const Idx rst = Map(DrawSlot());
        if (!used[rst]) {
          used[rst] = true;
          taken += prob_data[rst];
          return rst;
        }
      }
    }
    return Map(DrawSlot());
  }
};


/*
 * CDFSampler is used to sample elements from a given discrete categorical distribution.
 * Algorithm: create a cumulative distribution function and conduct binary search for sampling.
 * Reference: https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804
 * Sampler building complexity: O(n)
 * Sample w/ replacement complexity: O(log n)
 * Sample w/o replacement: rejection sampling with an O(log n) search per attempt; the CDF
 * is rebuilt (O(n)) over the not-yet-taken elements once at least half of the likelihood
 * mass it was built from has been taken.
 */
template <
  typename Idx,
  typename DType,
  bool replace>
class CDFSampler: public BaseSampler<Idx, DType, replace> {
 private:
  RandomEngine *re;
  Idx N;                        // number of elements covered by the current CDF
  DType accum, taken;           // accumulated likelihood of the CDF / of taken elements
  FloatArray _prob;             // categorical distribution, only stored when replace=false
  std::vector<DType> cdf;       // cumulative distribution function, cdf[0] == 0, size N + 1
  std::vector<bool> used;       // indicate availability, activated when replace=false;
  std::vector<Idx> id_mapping;  // indicate index mapping, activated when replace=false;

  // Map consecutive (compact) indices to unused elements of the distribution.
  inline Idx Map(Idx x) const {
    if (replace)
      return x;
    else
      return id_mapping[x];
  }

  // Rebuild the CDF over all elements that are not marked as used.
  // Resets N, accum and taken, and (when replace=false) the compact index mapping.
  void Reconstruct(FloatArray prob) {
    const int64_t prob_size = prob->shape[0];
    const DType *prob_data = static_cast<DType *>(prob->data);
    N = 0;
    accum = 0.;
    taken = 0.;
    if (!replace)
      id_mapping.clear();
    cdf.clear();
    cdf.push_back(0);
    for (Idx i = 0; i < prob_size; ++i) {
      if (used[i])
        continue;
      N++;
      accum += prob_data[i];
      if (!replace)
        id_mapping.push_back(i);
      cdf.push_back(accum);
    }
    if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'";
  }

  // Draw one compact index by binary search in the current CDF.
  inline Idx DrawSlot() {
    // Keep p strictly positive so that the search never stops on cdf[0].
    const DType eps = std::numeric_limits<DType>::min();
    const DType p = std::max(re->Uniform<DType>(0., accum), eps);
    const Idx pos = static_cast<Idx>(
        std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin()) - 1;
    // If p is larger than every CDF entry (e.g. accum == 0) the search returns end();
    // clamp so that the result stays a valid compact index.
    return std::min<Idx>(pos, N - 1);
  }

 public:
  /*!
   * \brief Forget all previously taken elements and rebuild the CDF from \a prob.
   * \param prob 1-D array of likelihoods; its first dimension gives the population size.
   */
  void ResetState(FloatArray prob) {
    used.resize(prob->shape[0]);
    if (!replace)
      _prob = prob;
    std::fill(used.begin(), used.end(), false);
    Reconstruct(prob);
  }

  /*!
   * \brief Build a sampler.
   * \param re   random engine used for every draw (not owned).
   * \param prob 1-D array of likelihoods.
   */
  explicit CDFSampler(RandomEngine *re, FloatArray prob): re(re) {
    ResetState(prob);
  }

  ~CDFSampler() {}

  /*!
   * \brief Draw one index of the distribution.
   *
   * When replace=false the returned element is marked as used and is never returned
   * again until ResetState() is called; drawing more elements than the population
   * triggers LOG(FATAL) in the CDF rebuild.
   */
  Idx Draw() override {
    if (!replace) {
      // _prob is only assigned when replace=false, so it is only touched here.
      const DType *prob_data = static_cast<DType *>(_prob->data);
      if (2 * taken >= accum)
        Reconstruct(_prob);
      while (true) {   // rejection sampling: retry until an unused element is hit
        const Idx rst = Map(DrawSlot());
        if (!used[rst]) {
          used[rst] = true;
          taken += prob_data[rst];
          return rst;
        }
      }
    }
    return Map(DrawSlot());
  }
};


/*
 * TreeSampler is used to sample elements from a given discrete categorical distribution.
 * Algorithm: create a heap that stores accumulated likelihood of its leaf descendants.
 * Reference: https://blog.smola.org/post/1016514759
 * Sampler building complexity: O(n)
 * Sample w/ and w/o replacement complexity: O(log n)
 * Note: unlike the other samplers in this file, drawing more elements than the
 * population with replace=false is not detected here.
 */
template <
  typename Idx,
  typename DType,
  bool replace>
class TreeSampler: public BaseSampler<Idx, DType, replace> {
 private:
  RandomEngine *re;
  std::vector<DType> weight;    // accumulated likelihood of subtrees, root at index 1.
  int64_t N;                    // size of the weight array (2 * num_leafs)
  int64_t num_leafs;            // smallest power of two that is >= the population size

 public:
  /*!
   * \brief Rebuild the tree from \a prob, restoring any element removed by Draw().
   *
   * The tree is sized here, so \a prob may have a different length than the array the
   * sampler was constructed with.
   * \param prob 1-D array of likelihoods; its first dimension gives the population size.
   */
  void ResetState(FloatArray prob) {
    const int64_t prob_size = prob->shape[0];
    const DType *prob_data = static_cast<DType *>(prob->data);
    num_leafs = 1;
    while (num_leafs < prob_size)
      num_leafs *= 2;
    N = num_leafs * 2;
    weight.assign(N, DType(0));   // padding leaves keep a zero weight
    // Leaves occupy [num_leafs, 2 * num_leafs).
    std::copy(prob_data, prob_data + prob_size, weight.begin() + num_leafs);
    // Each inner node stores the sum of its two children.
    for (int64_t i = num_leafs - 1; i >= 1; --i)
      weight[i] = weight[i * 2] + weight[i * 2 + 1];
  }

  /*!
   * \brief Build a sampler.
   * \param re   random engine used for every draw (not owned).
   * \param prob 1-D array of likelihoods.
   */
  explicit TreeSampler(RandomEngine *re, FloatArray prob): re(re) {
    ResetState(prob);
  }

  /*!
   * \brief Draw one index of the distribution.
   *
   * When replace=false the weight of the returned leaf is set to zero and the sums on
   * the path to the root are updated, so the element cannot be drawn again.
   */
  Idx Draw() override {
    int64_t cur = 1;
    const DType p = re->Uniform<DType>(0, weight[cur]);
    DType accum = 0.;   // likelihood mass lying to the left of the current subtree
    while (cur < num_leafs) {
      const DType w_l = weight[cur * 2], w_r = weight[cur * 2 + 1];
      const DType pivot = accum + w_l;
      // Never enter an empty right subtree (w_r > 0 suppresses numerical problems),
      // and never enter an empty left subtree while the right one still has mass
      // (otherwise p == pivot would select a zero-weight / already removed element).
      const bool go_right = w_r > 0 && (p > pivot || !(w_l > 0));
      cur = cur * 2 + (go_right ? 1 : 0);
      if (go_right)
        accum = pivot;
    }
    const Idx rst = static_cast<Idx>(cur - num_leafs);
    if (!replace) {
      // Remove the leaf, then refresh the partial sums of all its ancestors.
      weight[cur] = 0.;
      for (cur /= 2; cur >= 1; cur /= 2)
        weight[cur] = weight[cur * 2] + weight[cur * 2 + 1];
    }
    return rst;
  }
};

};  // namespace utils
};  // namespace dgl

#endif  // DGL_RANDOM_CPU_SAMPLE_UTILS_H_