"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "9a5341bde3400c426385ba60b697ff95f54f89c5"
rowwise_sampling.cc 11.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/rowwise_sampling.cc
 * \brief rowwise sampling
 */
#include <dgl/random.h>
#include <numeric>
#include "./rowwise_pick.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {
// Gathers `len` values of `array` starting at logical position `off` into a
// freshly allocated 1-D array with the same dtype and context as `array`.
//   * idx_data != nullptr: numpy equivalent is array[idx[off:off + len]]
//   * idx_data == nullptr: numpy equivalent is array[off:off + len]
template <typename IdxType, typename FloatType>
inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data,
                              IdxType off, IdxType len) {
  FloatArray sliced = FloatArray::Empty({len}, array->dtype, array->ctx);
  const FloatType* src = static_cast<FloatType*>(array->data);
  FloatType* dst = static_cast<FloatType*>(sliced->data);
  if (idx_data == nullptr) {
    // No index indirection: copy the contiguous window directly.
    for (int64_t k = 0; k < len; ++k) {
      dst[k] = src[off + k];
    }
  } else {
    // Look each position up through the index array first.
    for (int64_t k = 0; k < len; ++k) {
      dst[k] = src[idx_data[off + k]];
    }
  }
  return sliced;
}

// Builds a PickFn that draws `num_samples` entries from one row, weighting the
// row's entries by `prob`.
//
// For a row whose entries occupy positions [off, off + len):
//   1. the row's weights are gathered from `prob` via DoubleSlice (through
//      `data` when it is non-null, otherwise by position);
//   2. RandomEngine::Choice writes `num_samples` picks into `out_idx`
//      (assumed to be row-local positions in [0, len) -- see RandomEngine);
//   3. each pick is shifted by `off` so it indexes the full indices/data arrays.
// `out_idx` must have room for `num_samples` entries.
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingPickFn(
    int64_t num_samples, FloatArray prob, bool replace) {
  PickFn<IdxType> pick_fn = [prob, num_samples, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
      FloatArray prob_selected = DoubleSlice<IdxType, FloatType>(prob, data, off, len);
      RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
          num_samples, prob_selected, out_idx, replace);
      // Convert row-local picks into positions in the full arrays.
      for (int64_t j = 0; j < num_samples; ++j) {
        out_idx[j] += off;
      }
    };
  return pick_fn;
}

47
48
template <typename IdxType, typename FloatType>
inline RangePickFn<IdxType> GetSamplingRangePickFn(
49
    const std::vector<int64_t>& num_samples, FloatArray prob, bool replace) {
50
  RangePickFn<IdxType> pick_fn = [prob, num_samples, replace]
51
    (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
52
53
54
55
56
57
58
59
60
61
62
63
64
    const std::vector<IdxType> &et_idx,
    const IdxType* data, IdxType* out_idx) {
      const FloatType* p_data = static_cast<FloatType*>(prob->data);
      FloatArray probs = FloatArray::Empty({et_len}, prob->dtype, prob->ctx);
      FloatType* probs_data = static_cast<FloatType*>(probs->data);
      for (int64_t j = 0; j < et_len; ++j) {
        if (data)
          probs_data[j] = p_data[data[off+et_idx[et_offset+j]]];
        else
          probs_data[j] = p_data[off+et_idx[et_offset+j]];
      }

      RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
65
          num_samples[cur_et], probs, out_idx, replace);
66
67
68
69
    };
  return pick_fn;
}

70
71
72
73
74
75
76
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
    int64_t num_samples, bool replace) {
  PickFn<IdxType> pick_fn = [num_samples, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
77
78
      RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
          num_samples, len, out_idx, replace);
79
      for (int64_t j = 0; j < num_samples; ++j) {
80
        out_idx[j] += off;
81
82
83
84
      }
    };
  return pick_fn;
}
85

86
87
template <typename IdxType>
inline RangePickFn<IdxType> GetSamplingUniformRangePickFn(
88
    const std::vector<int64_t>& num_samples, bool replace) {
89
  RangePickFn<IdxType> pick_fn = [num_samples, replace]
90
    (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
91
92
93
    const std::vector<IdxType> &et_idx,
    const IdxType* data, IdxType* out_idx) {
      RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
94
          num_samples[cur_et], et_len, out_idx, replace);
95
96
97
98
    };
  return pick_fn;
}

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  PickFn<IdxType> pick_fn = [num_samples, split, bias, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
    const IdxType *tag_offset = static_cast<IdxType *>(split->data) + rowid * split->shape[1];
    RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
            num_samples, tag_offset, bias, out_idx, replace);
    for (int64_t j = 0; j < num_samples; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

116
117
118
119
}  // namespace

/////////////////////////////// CSR ///////////////////////////////

120
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
121
122
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
                             FloatArray prob, bool replace) {
123
  CHECK(prob.defined());
124
125
126
127
  auto pick_fn = GetSamplingPickFn<IdxType, FloatType>(num_samples, prob, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}

128
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
129
    CSRMatrix, IdArray, int64_t, FloatArray, bool);
130
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
131
    CSRMatrix, IdArray, int64_t, FloatArray, bool);
132
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
133
    CSRMatrix, IdArray, int64_t, FloatArray, bool);
134
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
135
136
    CSRMatrix, IdArray, int64_t, FloatArray, bool);

137
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
138
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
139
140
                                     const std::vector<int64_t>& num_samples,
                                     FloatArray prob, bool replace, bool etype_sorted) {
141
142
  CHECK(prob.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
143
  return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
144
145
}

146
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
147
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
148
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
149
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
150
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
151
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
152
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
153
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
154

155
template <DGLDeviceType XPU, typename IdxType>
156
157
158
159
160
161
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
                                    int64_t num_samples, bool replace) {
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}

162
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
163
    CSRMatrix, IdArray, int64_t, bool);
164
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
165
166
    CSRMatrix, IdArray, int64_t, bool);

167
template <DGLDeviceType XPU, typename IdxType>
168
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
169
                                            const std::vector<int64_t>& num_samples,
170
                                            bool replace, bool etype_sorted) {
171
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
172
  return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
173
174
}

175
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
176
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
177
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
178
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
179

180
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
181
182
183
184
185
186
187
188
189
190
191
192
193
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat,
    IdArray rows,
    int64_t num_samples,
    NDArray tag_offset,
    FloatArray bias,
    bool replace
) {
  auto pick_fn = GetSamplingBiasedPickFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}

194
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
195
196
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

197
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
198
199
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

200
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
201
202
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

203
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
204
205
206
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);


207
208
/////////////////////////////// COO ///////////////////////////////

209
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
210
211
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
                             FloatArray prob, bool replace) {
212
  CHECK(prob.defined());
213
214
215
216
  auto pick_fn = GetSamplingPickFn<IdxType, FloatType>(num_samples, prob, replace);
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
}

217
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
218
    COOMatrix, IdArray, int64_t, FloatArray, bool);
219
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
220
    COOMatrix, IdArray, int64_t, FloatArray, bool);
221
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
222
    COOMatrix, IdArray, int64_t, FloatArray, bool);
223
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
224
225
    COOMatrix, IdArray, int64_t, FloatArray, bool);

226
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
227
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
228
229
                                     const std::vector<int64_t>& num_samples,
                                     FloatArray prob, bool replace, bool etype_sorted) {
230
231
  CHECK(prob.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
232
  return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
233
234
}

235
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
236
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
237
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
238
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
239
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
240
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
241
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
242
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
243

244
template <DGLDeviceType XPU, typename IdxType>
245
246
247
248
249
250
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
                                    int64_t num_samples, bool replace) {
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
}

251
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
252
    COOMatrix, IdArray, int64_t, bool);
253
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
254
255
    COOMatrix, IdArray, int64_t, bool);

256
template <DGLDeviceType XPU, typename IdxType>
257
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
258
259
                                    const std::vector<int64_t>& num_samples,
                                    bool replace, bool etype_sorted) {
260
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
261
  return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
262
263
}

264
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
265
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
266
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
267
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
268

269
270
271
}  // namespace impl
}  // namespace aten
}  // namespace dgl