choice.cc 6.79 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2019 by Contributors
 * \file random/choice.cc
 * \brief Non-uniform discrete sampling implementation
 */

7
#include <dgl/array.h>
#include <dgl/random.h>

#include <algorithm>
#include <numeric>
#include <unordered_set>
#include <vector>

#include "sample_utils.h"

namespace dgl {

15
template <typename IdxType>
16
IdxType RandomEngine::Choice(FloatArray prob) {
17
  IdxType ret = 0;
18
  ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, "probability", {
19
    // TODO(minjie): allow choosing different sampling algorithms
20
21
22
23
24
25
26
27
28
    utils::TreeSampler<IdxType, ValueType, true> sampler(this, prob);
    ret = sampler.Draw();
  });
  return ret;
}

// Explicit instantiations of the single-draw Choice for 32- and 64-bit
// index types.
template int32_t RandomEngine::Choice<int32_t>(FloatArray prob);
template int64_t RandomEngine::Choice<int64_t>(FloatArray prob);

29
30
31
template <typename IdxType, typename FloatType>
void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out,
                          bool replace) {
32
  const IdxType N = prob->shape[0];
33
  if (!replace)
34
35
36
    CHECK_LE(num, N)
      << "Cannot take more sample than population when 'replace=false'";
  if (num == N && !replace) std::iota(out, out + num, 0);
37
38
39
40
41
42
43

  utils::BaseSampler<IdxType>* sampler = nullptr;
  if (replace) {
    sampler = new utils::TreeSampler<IdxType, FloatType, true>(this, prob);
  } else {
    sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob);
  }
44
  for (IdxType i = 0; i < num; ++i) out[i] = sampler->Draw();
45
46
47
  delete sampler;
}

48
49
50
51
52
53
54
55
56
57
// Explicit instantiations of the multi-draw Choice for every supported
// (index type, float type) combination.
template void RandomEngine::Choice<int32_t, float>(
    int32_t, FloatArray, int32_t*, bool);
template void RandomEngine::Choice<int64_t, float>(
    int64_t, FloatArray, int64_t*, bool);
template void RandomEngine::Choice<int32_t, double>(
    int32_t, FloatArray, int32_t*, bool);
template void RandomEngine::Choice<int64_t, double>(
    int64_t, FloatArray, int64_t*, bool);
58
59

template <typename IdxType>
60
61
void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
                                 bool replace) {
62
  if (!replace)
63
64
    CHECK_LE(num, population)
      << "Cannot take more sample than population when 'replace=false'";
65
  if (replace) {
66
    for (IdxType i = 0; i < num; ++i) out[i] = RandInt(population);
67
  } else {
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    if (num <
        population / 10) {  // TODO(minjie): may need a better threshold here
      // if set of numbers is small (up to 128) use linear search to verify
      // uniqueness this operation is cheaper for CPU.
      if (num && num < 64) {
        *out = RandInt(population);
        auto b = out + 1;
        auto e = b + num - 1;
        while (b != e) {
          // put the new value at the end
          *b = RandInt(population);
          // Check if a new value doesn't exist in current range(out,b)
          // otherwise get a new value until we haven't unique range of
          // elements.
          auto it = std::find(out, b, *b);
          if (it != b) continue;
          ++b;
        }

      } else {
        // use hash set
        // In the best scenario, time complexity is O(num), i.e., no conflict.
        //
        // Let k be num / population, the expected number of extra sampling
        // steps is roughly k^2 / (1-k) * population, which means in the worst
        // case scenario, the time complexity is O(population^2). In practice,
        // we use 1/10 since std::unordered_set is pretty slow.
        std::unordered_set<IdxType> selected;
        while (selected.size() < num) {
          selected.insert(RandInt(population));
        }
        std::copy(selected.begin(), selected.end(), out);
100
      }
101

102
103
104
    } else {
      // reservoir algorithm
      // time: O(population), space: O(num)
105
      for (IdxType i = 0; i < num; ++i) out[i] = i;
106
      for (IdxType i = num; i < population; ++i) {
107
        const IdxType j = RandInt(i + 1);
108
        if (j < num) out[j] = i;
109
      }
110
111
112
113
    }
  }
}

114
115
116
117
118
119
// Explicit instantiations of UniformChoice for 32- and 64-bit index types.
template void RandomEngine::UniformChoice<int32_t>(
    int32_t, int32_t, int32_t*, bool);
template void RandomEngine::UniformChoice<int64_t>(
    int64_t, int64_t, int64_t*, bool);
120

121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
template <typename IdxType, typename FloatType>
void RandomEngine::BiasedChoice(
    IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace) {
  const int64_t num_tags = bias->shape[0];
  const FloatType *bias_data = static_cast<FloatType *>(bias->data);
  IdxType total_node_num = 0;
  FloatArray prob = NDArray::Empty({num_tags}, bias->dtype, bias->ctx);
  FloatType *prob_data = static_cast<FloatType *>(prob->data);
  for (int64_t tag = 0 ; tag < num_tags; ++tag) {
    int64_t tag_num_nodes = split[tag+1] - split[tag];
    total_node_num += tag_num_nodes;
    FloatType tag_bias = bias_data[tag];
    prob_data[tag] = tag_num_nodes * tag_bias;
  }
  if (replace) {
    auto sampler = utils::TreeSampler<IdxType, FloatType, true>(this, prob);
    for (IdxType i = 0; i < num; ++i) {
      const int64_t tag = sampler.Draw();
      const IdxType tag_num_nodes = split[tag+1] - split[tag];
      out[i] = RandInt(tag_num_nodes) + split[tag];
    }
  } else {
    utils::TreeSampler<int64_t, FloatType, false> sampler(this, prob, bias_data);
    CHECK_GE(total_node_num, num)
        << "Cannot take more sample than population when 'replace=false'";
    // we use hash set here. Maybe in the future we should support reservoir algorithm
    std::vector<std::unordered_set<IdxType>> selected(num_tags);
    for (IdxType i = 0 ; i < num ; ++i) {
      const int64_t tag = sampler.Draw();
      bool inserted = false;
      const IdxType tag_num_nodes = split[tag+1] - split[tag];
      IdxType selected_node;
      while (!inserted) {
        CHECK_LT(selected[tag].size(), tag_num_nodes)
            << "Cannot take more sample than population when 'replace=false'";
        selected_node = RandInt(tag_num_nodes);
        inserted = selected[tag].insert(selected_node).second;
      }
      out[i] = selected_node + split[tag];
    }
  }
}

// Explicit instantiations of BiasedChoice for every supported
// (index type, float type) combination.
template void RandomEngine::BiasedChoice<int32_t, float>(
    int32_t num, const int32_t* split, FloatArray bias, int32_t* out,
    bool replace);
template void RandomEngine::BiasedChoice<int32_t, double>(
    int32_t num, const int32_t* split, FloatArray bias, int32_t* out,
    bool replace);
template void RandomEngine::BiasedChoice<int64_t, float>(
    int64_t num, const int64_t* split, FloatArray bias, int64_t* out,
    bool replace);
template void RandomEngine::BiasedChoice<int64_t, double>(
    int64_t num, const int64_t* split, FloatArray bias, int64_t* out,
    bool replace);

173
};  // namespace dgl