/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/rowwise_sampling.cc
 * \brief CPU implementation of row-wise sampling
 */
#include <dgl/random.h>
#include <numeric>
#include "./rowwise_pick.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {
/*!
 * \brief Gather a slice of \a array, optionally through an index buffer.
 *
 * Equivalent to the numpy expression array[idx[off:off + len]] when
 * \a idx_data is non-null, and to array[off:off + len] when it is null.
 *
 * \param array Source values, read as FloatType.
 * \param idx_data Optional index buffer; may be null.
 * \param off First position of the slice.
 * \param len Number of elements to gather.
 * \return An array of \a len elements created with the dtype and context of \a array.
 */
template <typename IdxType, typename FloatType>
inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data,
                              IdxType off, IdxType len) {
  const FloatType* src = static_cast<FloatType*>(array->data);
  FloatArray gathered = FloatArray::Empty({len}, array->dtype, array->ctx);
  FloatType* dst = static_cast<FloatType*>(gathered->data);
  if (idx_data != nullptr) {
    // Indirect gather: positions are translated through idx_data first.
    for (int64_t k = 0; k < len; ++k) {
      dst[k] = src[idx_data[off + k]];
    }
  } else {
    // Direct gather: positions index the source array as-is.
    for (int64_t k = 0; k < len; ++k) {
      dst[k] = src[off + k];
    }
  }
  return gathered;
}

30
31
32
33
template <typename IdxType, typename DType>
inline NumPicksFn<IdxType> GetSamplingNumPicksFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [prob_or_mask, num_samples, replace]
34
    (IdxType rowid, IdxType off, IdxType len,
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
     const IdxType* col, const IdxType* data) {
      const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
      const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
      IdxType nnz = 0;
      for (IdxType i = off; i < off + len; ++i) {
        if (prob_or_mask_data[i] > 0) {
          ++nnz;
        }
      }

      if (replace) {
        return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
      } else {
        return std::min(static_cast<IdxType>(max_num_picks), nnz);
      }
    };
  return num_picks_fn;
}

/*!
 * \brief Build the functor that picks entries from one row when sampling with
 *        probabilities or a mask.
 *
 * The functor gathers the row's probability/mask values with DoubleSlice
 * (going through \a data when it is non-null), asks the thread-local
 * RandomEngine to choose \a num_picks of them, and adds the row offset to
 * every chosen index.
 *
 * \param num_samples Not used by the functor; kept for a uniform signature.
 * \param prob_or_mask Probability or mask values, read as DType.
 * \param replace Whether sampling is done with replacement.
 * \return The pick functor.
 */
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetSamplingPickFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  PickFn<IdxType> pick_fn = [prob_or_mask, replace]
    (IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
      NDArray prob_or_mask_selected = DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);
      RandomEngine::ThreadLocal()->Choice<IdxType, DType>(
          num_picks, prob_or_mask_selected, out_idx, replace);
      // The picks are shifted by `off`; presumably Choice returns positions
      // within the len-element slice -- confirm against RandomEngine.
      for (int64_t j = 0; j < num_picks; ++j) {
        out_idx[j] += off;
      }
    };
  return pick_fn;
}

71
72
template <typename IdxType, typename FloatType>
inline RangePickFn<IdxType> GetSamplingRangePickFn(
73
    const std::vector<int64_t>& num_samples, FloatArray prob, bool replace) {
74
  RangePickFn<IdxType> pick_fn = [prob, num_samples, replace]
75
    (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
76
77
78
79
80
81
82
83
84
85
86
87
88
    const std::vector<IdxType> &et_idx,
    const IdxType* data, IdxType* out_idx) {
      const FloatType* p_data = static_cast<FloatType*>(prob->data);
      FloatArray probs = FloatArray::Empty({et_len}, prob->dtype, prob->ctx);
      FloatType* probs_data = static_cast<FloatType*>(probs->data);
      for (int64_t j = 0; j < et_len; ++j) {
        if (data)
          probs_data[j] = p_data[data[off+et_idx[et_offset+j]]];
        else
          probs_data[j] = p_data[off+et_idx[et_offset+j]];
      }

      RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
89
          num_samples[cur_et], probs, out_idx, replace);
90
91
92
93
    };
  return pick_fn;
}

94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
template <typename IdxType>
inline NumPicksFn<IdxType> GetSamplingUniformNumPicksFn(
    int64_t num_samples, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data) {
      const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
      if (replace) {
        return static_cast<IdxType>(len == 0 ? 0 : max_num_picks);
      } else {
        return std::min(static_cast<IdxType>(max_num_picks), len);
      }
    };
  return num_picks_fn;
}

110
111
112
113
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
    int64_t num_samples, bool replace) {
  PickFn<IdxType> pick_fn = [num_samples, replace]
114
    (IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
115
116
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
117
      RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
118
119
          num_picks, len, out_idx, replace);
      for (int64_t j = 0; j < num_picks; ++j) {
120
        out_idx[j] += off;
121
122
123
124
      }
    };
  return pick_fn;
}
125

126
127
template <typename IdxType>
inline RangePickFn<IdxType> GetSamplingUniformRangePickFn(
128
    const std::vector<int64_t>& num_samples, bool replace) {
129
  RangePickFn<IdxType> pick_fn = [num_samples, replace]
130
    (IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
131
132
133
    const std::vector<IdxType> &et_idx,
    const IdxType* data, IdxType* out_idx) {
      RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
134
          num_samples[cur_et], et_len, out_idx, replace);
135
136
137
138
    };
  return pick_fn;
}

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
template <typename IdxType, typename FloatType>
inline NumPicksFn<IdxType> GetSamplingBiasedNumPicksFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, split, bias, replace]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data) {
      const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
      const int64_t num_tags = split->shape[1] - 1;
      const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
      const FloatType* bias_data = bias.Ptr<FloatType>();
      IdxType nnz = 0;
      for (int64_t j = 0; j < num_tags; ++j) {
        if (bias_data[j] > 0) {
          nnz += tag_offset[j + 1] - tag_offset[j];
        }
      }

      if (replace) {
        return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
      } else {
        return std::min(static_cast<IdxType>(max_num_picks), nnz);
      }
    };
  return num_picks_fn;
}

165
166
167
168
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  PickFn<IdxType> pick_fn = [num_samples, split, bias, replace]
169
    (IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
170
171
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
172
    const IdxType *tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
173
    RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
174
175
            num_picks, tag_offset, bias, out_idx, replace);
    for (int64_t j = 0; j < num_picks; ++j) {
176
177
178
179
180
181
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

182
183
184
185
}  // namespace

/////////////////////////////// CSR ///////////////////////////////

186
template <DGLDeviceType XPU, typename IdxType, typename DType>
187
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
188
189
190
191
192
193
194
195
196
                             NDArray prob_or_mask, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn = GetSamplingNumPicksFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  auto pick_fn = GetSamplingPickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
197
198
}

199
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
200
    CSRMatrix, IdArray, int64_t, NDArray, bool);
201
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
202
    CSRMatrix, IdArray, int64_t, NDArray, bool);
203
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
204
    CSRMatrix, IdArray, int64_t, NDArray, bool);
205
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
206
207
208
209
210
211
212
213
214
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
215

216
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
217
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
218
219
                                     const std::vector<int64_t>& num_samples,
                                     FloatArray prob, bool replace, bool etype_sorted) {
220
221
  CHECK(prob.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
222
  return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
223
224
}

225
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
226
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
227
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
228
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
229
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
230
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
231
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
232
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
233

234
template <DGLDeviceType XPU, typename IdxType>
235
236
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
                                    int64_t num_samples, bool replace) {
237
238
239
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
240
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
241
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
242
243
}

244
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
245
    CSRMatrix, IdArray, int64_t, bool);
246
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
247
248
    CSRMatrix, IdArray, int64_t, bool);

249
template <DGLDeviceType XPU, typename IdxType>
250
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
251
                                            const std::vector<int64_t>& num_samples,
252
                                            bool replace, bool etype_sorted) {
253
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
254
  return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
255
256
}

257
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
258
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
259
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
260
    CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
261

262
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
263
264
265
266
267
268
269
270
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat,
    IdArray rows,
    int64_t num_samples,
    NDArray tag_offset,
    FloatArray bias,
    bool replace
) {
271
272
273
274
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingBiasedNumPicksFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
275
276
  auto pick_fn = GetSamplingBiasedPickFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
277
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
278
279
}

280
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
281
282
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

283
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
284
285
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

286
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
287
288
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

289
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
290
291
292
  CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);


293
294
/////////////////////////////// COO ///////////////////////////////

295
template <DGLDeviceType XPU, typename IdxType, typename DType>
296
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
297
298
299
300
301
302
303
304
305
                             NDArray prob_or_mask, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn = GetSamplingNumPicksFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  auto pick_fn = GetSamplingPickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
306
307
}

308
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
309
    COOMatrix, IdArray, int64_t, NDArray, bool);
310
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
311
    COOMatrix, IdArray, int64_t, NDArray, bool);
312
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
313
    COOMatrix, IdArray, int64_t, NDArray, bool);
314
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
315
316
317
318
319
320
321
322
323
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
324

325
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
326
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
327
328
                                     const std::vector<int64_t>& num_samples,
                                     FloatArray prob, bool replace, bool etype_sorted) {
329
330
  CHECK(prob.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, FloatType>(num_samples, prob, replace);
331
  return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
332
333
}

334
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
335
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
336
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
337
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
338
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
339
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
340
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
341
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
342

343
template <DGLDeviceType XPU, typename IdxType>
344
345
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
                                    int64_t num_samples, bool replace) {
346
347
348
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
349
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
350
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
351
352
}

353
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
354
    COOMatrix, IdArray, int64_t, bool);
355
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
356
357
    COOMatrix, IdArray, int64_t, bool);

358
template <DGLDeviceType XPU, typename IdxType>
359
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
360
361
                                    const std::vector<int64_t>& num_samples,
                                    bool replace, bool etype_sorted) {
362
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
363
  return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
364
365
}

366
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
367
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
368
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
369
    COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
370

371
372
373
}  // namespace impl
}  // namespace aten
}  // namespace dgl