/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_sampling.cc
 * @brief CPU implementation of row-wise neighbor sampling (weighted, masked,
 *        uniform, per-edge-type and tag-biased) on CSR and COO matrices.
 */
#include <dgl/random.h>
7

8
#include <numeric>
9

10
11
12
13
14
15
16
17
#include "./rowwise_pick.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {
// Equivalent to numpy expression: array[idx[off:off + len]]
template <typename IdxType, typename FloatType>
18
19
inline FloatArray DoubleSlice(
    FloatArray array, const IdxType* idx_data, IdxType off, IdxType len) {
20
21
22
23
24
25
26
27
28
29
30
31
  const FloatType* array_data = static_cast<FloatType*>(array->data);
  FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);
  FloatType* ret_data = static_cast<FloatType*>(ret->data);
  for (int64_t j = 0; j < len; ++j) {
    if (idx_data)
      ret_data[j] = array_data[idx_data[off + j]];
    else
      ret_data[j] = array_data[off + j];
  }
  return ret;
}

32
33
34
template <typename IdxType, typename DType>
inline NumPicksFn<IdxType> GetSamplingNumPicksFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
35
36
37
38
39
40
41
42
43
44
45
  NumPicksFn<IdxType> num_picks_fn = [prob_or_mask, num_samples, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
    IdxType nnz = 0;
    for (IdxType i = off; i < off + len; ++i) {
      const IdxType eid = data ? data[i] : i;
      if (prob_or_mask_data[eid] > 0) {
        ++nnz;
46
      }
47
    }
48

49
50
51
52
53
54
    if (replace) {
      return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast<IdxType>(max_num_picks), nnz);
    }
  };
55
56
57
58
59
60
  return num_picks_fn;
}

/**
 * @brief Builds the callback that draws the neighbors of one row according
 *        to per-edge probabilities (or a mask).
 *
 * The row's weights are gathered with DoubleSlice. RandomEngine::Choice then
 * draws `num_picks` row-local positions, which are shifted by `off` so they
 * index into the row's edge range.
 *
 * @note `num_samples` is not read here: the per-row count arrives as
 *       `num_picks` (in this file it is paired with GetSamplingNumPicksFn).
 *       The parameter is kept so the signature matches the other factories.
 */
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetSamplingPickFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  // Capture only what the closure reads; capturing the unused `num_samples`
  // triggers clang's -Wunused-lambda-capture.
  PickFn<IdxType> pick_fn = [prob_or_mask, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    NDArray prob_or_mask_selected =
        DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);
    RandomEngine::ThreadLocal()->Choice<IdxType, DType>(
        num_picks, prob_or_mask_selected, out_idx, replace);
    // Choice produced offsets relative to the row start; make them absolute.
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

76
template <typename IdxType, typename FloatType>
77
78
79
inline EtypeRangePickFn<IdxType> GetSamplingRangePickFn(
    const std::vector<int64_t>& num_samples,
    const std::vector<FloatArray>& prob, bool replace) {
80
81
82
83
84
85
86
87
88
89
90
91
92
93
  EtypeRangePickFn<IdxType> pick_fn =
      [prob, num_samples, replace](
          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
          const std::vector<IdxType>& et_idx,
          const std::vector<IdxType>& et_eid, const IdxType* eid,
          IdxType* out_idx) {
        const FloatArray& p = prob[cur_et];
        const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr<FloatType>();
        FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);
        FloatType* probs_data = probs.Ptr<FloatType>();
        for (int64_t j = 0; j < et_len; ++j) {
          const IdxType cur_eid = et_eid[et_idx[et_offset + j]];
          probs_data[j] = p_data ? p_data[cur_eid] : static_cast<FloatType>(1.);
        }
94

95
96
97
        RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
            num_samples[cur_et], probs, out_idx, replace);
      };
98
99
100
  return pick_fn;
}

101
102
103
template <typename IdxType>
inline NumPicksFn<IdxType> GetSamplingUniformNumPicksFn(
    int64_t num_samples, bool replace) {
104
105
106
107
108
109
110
111
112
113
114
  NumPicksFn<IdxType> num_picks_fn = [num_samples, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    if (replace) {
      return static_cast<IdxType>(len == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast<IdxType>(max_num_picks), len);
    }
  };
115
116
117
  return num_picks_fn;
}

118
119
120
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
    int64_t num_samples, bool replace) {
121
122
123
124
125
126
127
128
129
130
  PickFn<IdxType> pick_fn = [num_samples, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
        num_picks, len, out_idx, replace);
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
131
132
  return pick_fn;
}
133

134
template <typename IdxType>
135
inline EtypeRangePickFn<IdxType> GetSamplingUniformRangePickFn(
136
    const std::vector<int64_t>& num_samples, bool replace) {
137
138
139
140
141
142
143
144
145
  EtypeRangePickFn<IdxType> pick_fn =
      [num_samples, replace](
          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
          const std::vector<IdxType>& et_idx,
          const std::vector<IdxType>& et_eid, const IdxType* data,
          IdxType* out_idx) {
        RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
            num_samples[cur_et], et_len, out_idx, replace);
      };
146
147
148
  return pick_fn;
}

149
150
151
template <typename IdxType, typename FloatType>
inline NumPicksFn<IdxType> GetSamplingBiasedNumPicksFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
152
153
154
155
156
157
158
159
160
161
162
163
  NumPicksFn<IdxType> num_picks_fn = [num_samples, split, bias, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const int64_t num_tags = split->shape[1] - 1;
    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
    const FloatType* bias_data = bias.Ptr<FloatType>();
    IdxType nnz = 0;
    for (int64_t j = 0; j < num_tags; ++j) {
      if (bias_data[j] > 0) {
        nnz += tag_offset[j + 1] - tag_offset[j];
164
      }
165
    }
166

167
168
169
170
171
172
    if (replace) {
      return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast<IdxType>(max_num_picks), nnz);
    }
  };
173
174
175
  return num_picks_fn;
}

176
177
178
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
179
180
181
182
183
  PickFn<IdxType> pick_fn = [num_samples, split, bias, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
184
    RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
185
        num_picks, tag_offset, bias, out_idx, replace);
186
    for (int64_t j = 0; j < num_picks; ++j) {
187
188
189
190
191
192
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

193
194
195
196
}  // namespace

/////////////////////////////// CSR ///////////////////////////////

197
template <DGLDeviceType XPU, typename IdxType, typename DType>
198
199
200
COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
201
202
203
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
204
205
206
207
  auto num_picks_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  auto pick_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
208
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
209
210
}

211
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
212
    CSRMatrix, IdArray, int64_t, NDArray, bool);
213
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
214
    CSRMatrix, IdArray, int64_t, NDArray, bool);
215
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
216
    CSRMatrix, IdArray, int64_t, NDArray, bool);
217
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
218
219
220
221
222
223
224
225
226
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, int8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
227

228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
template <
    DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdxType>* new_seed_nodes, int64_t num_samples,
    NDArray prob_or_mask, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  auto pick_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  return CSRRowWisePickFused<IdxType, map_seed_nodes>(
      mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn,
      num_picks_fn);
}

template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, float, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, float, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, double, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, double, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);

template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, float, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, float, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, double, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, double, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);

296
297
298
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
299
300
301
302
303
304
305
306
307
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors does not match the number of edge "
         "types.";
  for (auto& p : prob_or_mask) CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
308
  return CSRRowWisePerEtypePick<IdxType, DType>(
309
310
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, prob_or_mask);
311
312
}

313
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
314
315
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
316
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
317
318
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
319
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
320
321
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
322
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
323
324
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
325
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
326
327
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
328
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
329
330
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
331
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
332
333
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
334
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
335
336
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
337

338
template <DGLDeviceType XPU, typename IdxType>
339
340
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
341
342
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
343
344
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
345
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
346
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
347
348
}

349
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
350
    CSRMatrix, IdArray, int64_t, bool);
351
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
352
353
    CSRMatrix, IdArray, int64_t, bool);

354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
template <DGLDeviceType XPU, typename IdxType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingUniformFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdxType>* new_seed_nodes, int64_t num_samples, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
  return CSRRowWisePickFused<IdxType, map_seed_nodes>(
      mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn,
      num_picks_fn);
}

template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);
template std::pair<CSRMatrix, IdArray>
CSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);

381
template <DGLDeviceType XPU, typename IdxType>
382
383
COOMatrix CSRRowWisePerEtypeSamplingUniform(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
384
385
    const std::vector<int64_t>& num_samples, bool replace,
    bool rowwise_etype_sorted) {
386
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
387
  return CSRRowWisePerEtypePick<IdxType, float>(
388
389
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, {});
390
391
}

392
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
393
394
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool, bool);
395
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
396
397
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool, bool);
398

399
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
400
COOMatrix CSRRowWiseSamplingBiased(
401
402
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace) {
403
404
405
406
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingBiasedNumPicksFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
407
408
  auto pick_fn = GetSamplingBiasedPickFn<IdxType, FloatType>(
      num_samples, tag_offset, bias, replace);
409
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
410
411
}

412
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
413
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
414

415
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
416
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
417

418
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
419
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
420

421
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
422
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
423

424
425
/////////////////////////////// COO ///////////////////////////////

426
template <DGLDeviceType XPU, typename IdxType, typename DType>
427
428
429
COOMatrix COORowWiseSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
430
431
432
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
433
434
435
436
  auto num_picks_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  auto pick_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
437
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
438
439
}

440
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
441
    COOMatrix, IdArray, int64_t, NDArray, bool);
442
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
443
    COOMatrix, IdArray, int64_t, NDArray, bool);
444
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
445
    COOMatrix, IdArray, int64_t, NDArray, bool);
446
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
447
448
449
450
451
452
453
454
455
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, int8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(
    COOMatrix, IdArray, int64_t, NDArray, bool);
456

457
458
459
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
460
461
462
463
464
465
466
467
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace) {
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors do not match the number of edge "
         "types.";
  for (auto& p : prob_or_mask) CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
468
469
  return COORowWisePerEtypePick<IdxType, DType>(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, prob_or_mask);
470
471
}

472
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
473
474
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
475
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
476
477
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
478
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
479
480
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
481
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
482
483
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
484
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
485
486
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
487
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
488
489
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
490
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
491
492
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
493
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
494
495
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
496

497
template <DGLDeviceType XPU, typename IdxType>
498
499
COOMatrix COORowWiseSamplingUniform(
    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
500
501
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
502
503
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
504
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
505
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
506
507
}

508
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
509
    COOMatrix, IdArray, int64_t, bool);
510
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
511
512
    COOMatrix, IdArray, int64_t, bool);

513
template <DGLDeviceType XPU, typename IdxType>
514
515
516
COOMatrix COORowWisePerEtypeSamplingUniform(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace) {
517
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
518
519
  return COORowWisePerEtypePick<IdxType, float>(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, {});
520
521
}

522
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
523
524
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool);
525
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
526
527
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool);
528

529
530
531
}  // namespace impl
}  // namespace aten
}  // namespace dgl