rowwise_sampling.cc 24 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
2
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_sampling.cc
 * @brief rowwise sampling
 */
#include <dgl/random.h>
8

9
#include <numeric>
10

sangwzh's avatar
sangwzh committed
11
#include "rowwise_pick.h"
12
13
14
15
16
17
18

namespace dgl {
namespace aten {
namespace impl {
namespace {
// Equivalent to numpy expression: array[idx[off:off + len]]
template <typename IdxType, typename FloatType>
19
20
inline FloatArray DoubleSlice(
    FloatArray array, const IdxType* idx_data, IdxType off, IdxType len) {
21
22
23
24
25
26
27
28
29
30
31
32
  const FloatType* array_data = static_cast<FloatType*>(array->data);
  FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);
  FloatType* ret_data = static_cast<FloatType*>(ret->data);
  for (int64_t j = 0; j < len; ++j) {
    if (idx_data)
      ret_data[j] = array_data[idx_data[off + j]];
    else
      ret_data[j] = array_data[off + j];
  }
  return ret;
}

33
34
35
/**
 * @brief Build the callback that decides how many entries to pick from a row
 *        when sampling with per-edge probabilities or a mask.
 *
 * The callback counts the entries of the row whose `prob_or_mask` value is
 * positive (`nnz`). With replacement it returns 0 when `nnz` is 0 and the
 * requested amount otherwise; without replacement it returns the smaller of
 * the requested amount and `nnz`. The requested amount is the row length when
 * `num_samples` is -1 and `num_samples` otherwise.
 *
 * @tparam IdxType Index type of the matrix.
 * @tparam DType Element type of `prob_or_mask`.
 * @param num_samples Requested picks per row, or -1 for "all".
 * @param prob_or_mask Per-edge values, indexed by edge id.
 * @param replace Whether sampling is done with replacement.
 * @return The num-picks callback.
 */
template <typename IdxType, typename DType>
inline NumPicksFn<IdxType> GetSamplingNumPicksFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [prob_or_mask, num_samples, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
    // Count the entries of this row that can be drawn at all.
    IdxType nnz = 0;
    for (IdxType i = off; i < off + len; ++i) {
      // Without an explicit data array the position itself is the edge id.
      const IdxType eid = data ? data[i] : i;
      if (prob_or_mask_data[eid] > 0) {
        ++nnz;
      }
    }

    if (replace) {
      return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
    }
    // Compare in 64 bits and narrow afterwards: narrowing `max_num_picks`
    // first could wrap a large fan-out into a small or negative IdxType value
    // that then wrongly wins the comparison against `nnz`.
    return static_cast<IdxType>(std::min<int64_t>(max_num_picks, nnz));
  };
  return num_picks_fn;
}

/**
 * @brief Build the callback that draws the picks of one row using per-edge
 *        probabilities or a mask.
 *
 * The callback gathers the row's values from `prob_or_mask` with DoubleSlice,
 * lets the thread-local RandomEngine choose `num_picks` entries into
 * `out_idx`, and finally adds the row offset `off` to every chosen entry.
 *
 * @tparam IdxType Index type of the matrix.
 * @tparam DType Element type of `prob_or_mask`.
 * @param num_samples Requested picks per row (captured, not read by the
 *        callback).
 * @param prob_or_mask Per-edge values, indexed by edge id.
 * @param replace Whether sampling is done with replacement.
 * @return The pick callback.
 */
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetSamplingPickFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  auto weighted_picker = [prob_or_mask, num_samples, replace](
                             IdxType rowid, IdxType off, IdxType len,
                             IdxType num_picks, const IdxType* col,
                             const IdxType* data, IdxType* out_idx) {
    // Weights of just this row, in row order.
    NDArray row_weights =
        DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);
    RandomEngine::ThreadLocal()->Choice<IdxType, DType>(
        num_picks, row_weights, out_idx, replace);
    // Shift every chosen entry by the row's start offset.
    for (IdxType k = 0; k < num_picks; ++k) {
      out_idx[k] += off;
    }
  };
  return weighted_picker;
}

77
template <typename IdxType, typename FloatType>
78
79
80
inline EtypeRangePickFn<IdxType> GetSamplingRangePickFn(
    const std::vector<int64_t>& num_samples,
    const std::vector<FloatArray>& prob, bool replace) {
81
82
83
84
85
86
87
88
89
90
91
92
93
94
  EtypeRangePickFn<IdxType> pick_fn =
      [prob, num_samples, replace](
          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
          const std::vector<IdxType>& et_idx,
          const std::vector<IdxType>& et_eid, const IdxType* eid,
          IdxType* out_idx) {
        const FloatArray& p = prob[cur_et];
        const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr<FloatType>();
        FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);
        FloatType* probs_data = probs.Ptr<FloatType>();
        for (int64_t j = 0; j < et_len; ++j) {
          const IdxType cur_eid = et_eid[et_idx[et_offset + j]];
          probs_data[j] = p_data ? p_data[cur_eid] : static_cast<FloatType>(1.);
        }
95

96
97
98
        RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
            num_samples[cur_et], probs, out_idx, replace);
      };
99
100
101
  return pick_fn;
}

102
103
104
/**
 * @brief Build the callback that decides how many entries to pick from a row
 *        under uniform sampling.
 *
 * With replacement the callback returns 0 for an empty row and the requested
 * amount otherwise; without replacement it returns the smaller of the
 * requested amount and the row length. The requested amount is the row length
 * when `num_samples` is -1 and `num_samples` otherwise.
 *
 * @tparam IdxType Index type of the matrix.
 * @param num_samples Requested picks per row, or -1 for "all".
 * @param replace Whether sampling is done with replacement.
 * @return The num-picks callback.
 */
template <typename IdxType>
inline NumPicksFn<IdxType> GetSamplingUniformNumPicksFn(
    int64_t num_samples, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    if (replace) {
      return static_cast<IdxType>(len == 0 ? 0 : max_num_picks);
    }
    // Compare in 64 bits and narrow afterwards: narrowing `max_num_picks`
    // first could wrap a large fan-out into a small or negative IdxType value
    // that then wrongly wins the comparison against `len`.
    return static_cast<IdxType>(std::min<int64_t>(max_num_picks, len));
  };
  return num_picks_fn;
}

119
120
121
/**
 * @brief Build the callback that draws the picks of one row uniformly.
 *
 * The callback lets the thread-local RandomEngine choose `num_picks` values
 * out of a population of size `len` into `out_idx` and then adds the row
 * offset `off` to every chosen entry.
 *
 * @tparam IdxType Index type of the matrix.
 * @param num_samples Requested picks per row (captured, not read by the
 *        callback).
 * @param replace Whether sampling is done with replacement.
 * @return The pick callback.
 */
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
    int64_t num_samples, bool replace) {
  auto uniform_picker = [num_samples, replace](
                            IdxType rowid, IdxType off, IdxType len,
                            IdxType num_picks, const IdxType* col,
                            const IdxType* data, IdxType* out_idx) {
    RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
        num_picks, len, out_idx, replace);
    // Shift every chosen entry by the row's start offset.
    for (IdxType k = 0; k < num_picks; ++k) {
      out_idx[k] += off;
    }
  };
  return uniform_picker;
}
134

135
template <typename IdxType>
136
inline EtypeRangePickFn<IdxType> GetSamplingUniformRangePickFn(
137
    const std::vector<int64_t>& num_samples, bool replace) {
138
139
140
141
142
143
144
145
146
  EtypeRangePickFn<IdxType> pick_fn =
      [num_samples, replace](
          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
          const std::vector<IdxType>& et_idx,
          const std::vector<IdxType>& et_eid, const IdxType* data,
          IdxType* out_idx) {
        RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
            num_samples[cur_et], et_len, out_idx, replace);
      };
147
148
149
  return pick_fn;
}

150
151
152
/**
 * @brief Build the callback that decides how many entries to pick from a row
 *        under tag-biased sampling.
 *
 * Row `rowid` of `split` holds `split->shape[1]` offsets delimiting
 * `split->shape[1] - 1` tag ranges. The callback sums the sizes of the ranges
 * whose bias is positive (`nnz`). With replacement it returns 0 when `nnz` is
 * 0 and the requested amount otherwise; without replacement it returns the
 * smaller of the requested amount and `nnz`. The requested amount is the row
 * length when `num_samples` is -1 and `num_samples` otherwise.
 *
 * @tparam IdxType Index type of the matrix and of `split`.
 * @tparam FloatType Element type of `bias`.
 * @param num_samples Requested picks per row, or -1 for "all".
 * @param split Per-row tag offsets, read as a 2-D array.
 * @param bias One bias value per tag.
 * @param replace Whether sampling is done with replacement.
 * @return The num-picks callback.
 */
template <typename IdxType, typename FloatType>
inline NumPicksFn<IdxType> GetSamplingBiasedNumPicksFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, split, bias, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const int64_t num_tags = split->shape[1] - 1;
    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
    const FloatType* bias_data = bias.Ptr<FloatType>();
    // Only entries under a tag with positive bias can be drawn.
    IdxType nnz = 0;
    for (int64_t j = 0; j < num_tags; ++j) {
      if (bias_data[j] > 0) {
        nnz += tag_offset[j + 1] - tag_offset[j];
      }
    }

    if (replace) {
      return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
    }
    // Compare in 64 bits and narrow afterwards: narrowing `max_num_picks`
    // first could wrap a large fan-out into a small or negative IdxType value
    // that then wrongly wins the comparison against `nnz`.
    return static_cast<IdxType>(std::min<int64_t>(max_num_picks, nnz));
  };
  return num_picks_fn;
}

177
178
179
/**
 * @brief Build the callback that draws the picks of one row under tag-biased
 *        sampling.
 *
 * The callback locates the tag offsets of row `rowid` inside `split`, lets the
 * thread-local RandomEngine perform a biased choice of `num_picks` entries
 * into `out_idx`, and then adds the row offset `off` to every chosen entry.
 *
 * @tparam IdxType Index type of the matrix and of `split`.
 * @tparam FloatType Element type of `bias`.
 * @param num_samples Requested picks per row (captured, not read by the
 *        callback).
 * @param split Per-row tag offsets, read as a 2-D array.
 * @param bias Bias values handed to the random engine.
 * @param replace Whether sampling is done with replacement.
 * @return The pick callback.
 */
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  auto biased_picker = [num_samples, split, bias, replace](
                           IdxType rowid, IdxType off, IdxType len,
                           IdxType num_picks, const IdxType* col,
                           const IdxType* data, IdxType* out_idx) {
    // Each row of `split` occupies split->shape[1] consecutive offsets.
    const int64_t row_stride = split->shape[1];
    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * row_stride;
    RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
        num_picks, tag_offset, bias, out_idx, replace);
    // Shift every chosen entry by the row's start offset.
    for (IdxType k = 0; k < num_picks; ++k) {
      out_idx[k] += off;
    }
  };
  return biased_picker;
}

194
195
196
197
}  // namespace

/////////////////////////////// CSR ///////////////////////////////

198
template <DGLDeviceType XPU, typename IdxType, typename DType>
199
200
201
COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
202
203
204
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
205
206
207
208
  auto num_picks_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  auto pick_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
209
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
210
211
}

212
// Explicit instantiations of CSRRowWiseSampling for kDGLCPU: index types
// int32_t/int64_t combined with value types float, double, int8_t and uint8_t.
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
213
    CSRMatrix, IdArray, int64_t, NDArray, bool);
214
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
215
    CSRMatrix, IdArray, int64_t, NDArray, bool);
216
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
217
    CSRMatrix, IdArray, int64_t, NDArray, bool);
218
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
219
220
221
222
223
224
225
226
227
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
228

229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
/**
 * @brief Fused variant of row-wise CSR sampling driven by per-edge
 *        probabilities or a mask.
 *
 * Builds the num-picks and pick callbacks from `prob_or_mask` and forwards
 * everything, including the seed-mapping arguments, to CSRRowWisePickFused.
 *
 * @tparam XPU Device type; not referenced in the body.
 * @tparam IdxType Index type of the matrix.
 * @tparam DType Element type of `prob_or_mask`.
 * @tparam map_seed_nodes Forwarded unchanged to CSRRowWisePickFused.
 * @param mat Input CSR matrix.
 * @param rows Rows to sample from.
 * @param seed_mapping Forwarded unchanged to CSRRowWisePickFused.
 * @param new_seed_nodes Forwarded unchanged to CSRRowWisePickFused.
 * @param num_samples Picks per row; -1 selects all neighbors.
 * @param prob_or_mask Per-edge probabilities or mask; must be defined.
 * @param replace Sample with replacement; ignored when `num_samples` is -1.
 * @return Whatever CSRRowWisePickFused returns for these callbacks.
 */
template <
    DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdxType>* new_seed_nodes, int64_t num_samples,
    NDArray prob_or_mask, bool replace) {
  // Selecting all neighbors (-1) is always done without replacement.
  if (num_samples == -1) replace = false;
  CHECK(prob_or_mask.defined());
  NumPicksFn<IdxType> count_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  PickFn<IdxType> choose_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  return CSRRowWisePickFused<IdxType, map_seed_nodes>(
      mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, choose_fn,
      count_fn);
}

// Explicit instantiations of CSRRowWiseSamplingFused for kDGLCPU: index types
// int32_t/int64_t combined with value types float, double, int8_t and uint8_t.
// First group: map_seed_nodes = true.
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, float, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, float, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, double, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, double, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);

// Second group: map_seed_nodes = false.
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, float, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, float, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, double, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, double, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);

297
298
299
/**
 * @brief Per-edge-type row-wise sampling on a CSR matrix driven by
 *        probabilities or masks.
 *
 * Verifies that one defined tensor is supplied per entry of `num_samples`,
 * builds the weighted range-pick callback and forwards to
 * CSRRowWisePerEtypePick.
 *
 * @tparam XPU Device type; not referenced in the body.
 * @tparam IdxType Index type of the matrix.
 * @tparam DType Element type of the tensors in `prob_or_mask`.
 * @param mat Input CSR matrix.
 * @param rows Rows to sample from.
 * @param eid2etype_offset Forwarded unchanged to CSRRowWisePerEtypePick.
 * @param num_samples Picks per edge type.
 * @param prob_or_mask One tensor per edge type; each must be defined.
 * @param replace Whether sampling is done with replacement.
 * @param rowwise_etype_sorted Forwarded unchanged to CSRRowWisePerEtypePick.
 * @return Whatever CSRRowWisePerEtypePick returns for this callback.
 */
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors does not match the number of edge "
         "types.";
  for (const NDArray& p : prob_or_mask) {
    CHECK(p.defined());
  }
  EtypeRangePickFn<IdxType> range_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return CSRRowWisePerEtypePick<IdxType, DType>(
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      range_fn, prob_or_mask);
}

314
// Explicit instantiations of CSRRowWisePerEtypeSampling for kDGLCPU: index
// types int32_t/int64_t combined with value types float, double, int8_t and
// uint8_t.
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
315
316
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
317
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
318
319
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
320
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
321
322
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
323
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
324
325
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
326
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
327
328
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
329
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
330
331
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
332
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
333
334
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
335
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
336
337
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
338

339
template <DGLDeviceType XPU, typename IdxType>
340
341
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
342
343
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
344
345
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
346
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
347
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
348
349
}

350
// Explicit instantiations of CSRRowWiseSamplingUniform for kDGLCPU with
// int32_t and int64_t index types.
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
351
    CSRMatrix, IdArray, int64_t, bool);
352
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
353
354
    CSRMatrix, IdArray, int64_t, bool);

355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
/**
 * @brief Fused variant of uniform row-wise sampling on a CSR matrix.
 *
 * Builds the uniform num-picks and pick callbacks and forwards everything,
 * including the seed-mapping arguments, to CSRRowWisePickFused.
 *
 * @tparam XPU Device type; not referenced in the body.
 * @tparam IdxType Index type of the matrix.
 * @tparam map_seed_nodes Forwarded unchanged to CSRRowWisePickFused.
 * @param mat Input CSR matrix.
 * @param rows Rows to sample from.
 * @param seed_mapping Forwarded unchanged to CSRRowWisePickFused.
 * @param new_seed_nodes Forwarded unchanged to CSRRowWisePickFused.
 * @param num_samples Picks per row; -1 selects all neighbors.
 * @param replace Sample with replacement; ignored when `num_samples` is -1.
 * @return Whatever CSRRowWisePickFused returns for these callbacks.
 */
template <DGLDeviceType XPU, typename IdxType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingUniformFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdxType>* new_seed_nodes, int64_t num_samples, bool replace) {
  // Selecting all neighbors (-1) is always done without replacement.
  if (num_samples == -1) replace = false;
  NumPicksFn<IdxType> count_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
  PickFn<IdxType> choose_fn =
      GetSamplingUniformPickFn<IdxType>(num_samples, replace);
  return CSRRowWisePickFused<IdxType, map_seed_nodes>(
      mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, choose_fn,
      count_fn);
}

// Explicit instantiations of CSRRowWiseSamplingUniformFused for kDGLCPU:
// int32_t/int64_t index types, each with map_seed_nodes true and false.
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);

382
template <DGLDeviceType XPU, typename IdxType>
383
384
COOMatrix CSRRowWisePerEtypeSamplingUniform(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
385
386
    const std::vector<int64_t>& num_samples, bool replace,
    bool rowwise_etype_sorted) {
387
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
388
  return CSRRowWisePerEtypePick<IdxType, float>(
389
390
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, {});
391
392
}

393
// Explicit instantiations of CSRRowWisePerEtypeSamplingUniform for kDGLCPU
// with int32_t and int64_t index types.
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
394
395
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool, bool);
396
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
397
398
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool, bool);
399

400
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
401
COOMatrix CSRRowWiseSamplingBiased(
402
403
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace) {
404
405
406
407
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingBiasedNumPicksFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
408
409
  auto pick_fn = GetSamplingBiasedPickFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
410
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
411
412
}

413
// Explicit instantiations of CSRRowWiseSamplingBiased for kDGLCPU: index types
// int32_t/int64_t combined with bias types float and double.
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
414
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
415

416
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
417
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
418

419
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
420
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
421

422
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
423
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
424

425
426
/////////////////////////////// COO ///////////////////////////////

427
template <DGLDeviceType XPU, typename IdxType, typename DType>
428
429
430
COOMatrix COORowWiseSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
431
432
433
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
434
435
436
437
  auto num_picks_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  auto pick_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
438
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
439
440
}

441
// Explicit instantiations of COORowWiseSampling for kDGLCPU: index types
// int32_t/int64_t combined with value types float, double, int8_t and uint8_t.
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
442
    COOMatrix, IdArray, int64_t, NDArray, bool);
443
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
444
    COOMatrix, IdArray, int64_t, NDArray, bool);
445
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
446
    COOMatrix, IdArray, int64_t, NDArray, bool);
447
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
448
449
450
451
452
453
454
455
456
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
457

458
459
460
/**
 * @brief Per-edge-type row-wise sampling on a COO matrix driven by
 *        probabilities or masks.
 *
 * Verifies that one defined tensor is supplied per entry of `num_samples`,
 * builds the weighted range-pick callback and forwards to
 * COORowWisePerEtypePick.
 *
 * @tparam XPU Device type; not referenced in the body.
 * @tparam IdxType Index type of the matrix.
 * @tparam DType Element type of the tensors in `prob_or_mask`.
 * @param mat Input COO matrix.
 * @param rows Rows to sample from.
 * @param eid2etype_offset Forwarded unchanged to COORowWisePerEtypePick.
 * @param num_samples Picks per edge type.
 * @param prob_or_mask One tensor per edge type; each must be defined.
 * @param replace Whether sampling is done with replacement.
 * @return Whatever COORowWisePerEtypePick returns for this callback.
 */
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace) {
  // Message wording matches the CSR counterpart ("the number ... does not").
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors does not match the number of edge "
         "types.";
  for (auto& p : prob_or_mask) CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return COORowWisePerEtypePick<IdxType, DType>(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, prob_or_mask);
}

473
// Explicit instantiations of COORowWisePerEtypeSampling for kDGLCPU: index
// types int32_t/int64_t combined with value types float, double, int8_t and
// uint8_t.
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
474
475
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
476
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
477
478
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
479
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
480
481
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
482
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
483
484
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
485
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
486
487
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
488
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
489
490
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
491
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
492
493
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
494
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
495
496
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
497

498
template <DGLDeviceType XPU, typename IdxType>
499
500
COOMatrix COORowWiseSamplingUniform(
    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
501
502
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
503
504
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
505
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
506
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
507
508
}

509
// Explicit instantiations of COORowWiseSamplingUniform for kDGLCPU with
// int32_t and int64_t index types.
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
510
    COOMatrix, IdArray, int64_t, bool);
511
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
512
513
    COOMatrix, IdArray, int64_t, bool);

514
template <DGLDeviceType XPU, typename IdxType>
515
516
517
COOMatrix COORowWisePerEtypeSamplingUniform(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace) {
518
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
519
520
  return COORowWisePerEtypePick<IdxType, float>(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, {});
521
522
}

523
// Explicit instantiations of COORowWisePerEtypeSamplingUniform for kDGLCPU
// with int32_t and int64_t index types.
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
524
525
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool);
526
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
527
528
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool);
529

530
531
532
}  // namespace impl
}  // namespace aten
}  // namespace dgl