"examples/multigpu/vscode:/vscode.git/clone" did not exist on "8f1b5782bb78a5285ea491626630d5a5cbe1512f"
rowwise_pick.h 22.6 KB
Newer Older
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_pick.h
 * @brief Template implementation for rowwise pick operators.
 */
#ifndef DGL_ARRAY_CPU_ROWWISE_PICK_H_
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_

#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

namespace dgl {
namespace aten {
namespace impl {

// User-defined function for picking elements from one row.
//
// The column indices of the given row are stored in
//   [col + off, col + off + len)
//
// Similarly, the data indices are stored in
//   [data + off, data + off + len)
// The data index pointer may be NULL, which means data[i] == i.
//
// The callers in this file pass an output buffer with room for exactly
// num_picks entries and read num_picks entries back from it.
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// @param rowid The row to pick from.
// @param off Starting offset of this row.
// @param len NNZ of the row.
// @param num_picks Number of picks on the row.
// @param col Pointer of the column indices.
// @param data Pointer of the data indices.
// @param out_idx Picked indices in [off, off + len).
template <typename IdxType>
using PickFn = std::function<void(
    IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
    const IdxType* col, const IdxType* data, IdxType* out_idx)>;
47

48
49
// User-defined function for determining the number of elements to pick from one
// row.
//
// The column indices of the given row are stored in
//   [col + off, col + off + len)
//
// Similarly, the data indices are stored in
//   [data + off, data + off + len)
// The data index pointer may be NULL, which means data[i] == i.
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// @param rowid The row to pick from.
// @param off Starting offset of this row.
// @param len NNZ of the row.
// @param col Pointer of the column indices.
// @param data Pointer of the data indices.
// @return The number of elements that will be picked from the row.
template <typename IdxType>
using NumPicksFn = std::function<IdxType(
    IdxType rowid, IdxType off, IdxType len, const IdxType* col,
    const IdxType* data)>;
70

71
72
73
74
75
76
77
// User-defined function for picking elements from a range within a row.
//
// The range holds et_len elements. For i in [0, et_len), the i-th element of
// the range is located at position
//   off + et_idx[et_offset + i]
// of the column index array.
//
// Similarly, the data index of that element is
//   eid[off + et_idx[et_offset + i]]
// The eid pointer may be NULL, which means the data index equals the position
// itself, i.e. off + et_idx[et_offset + i].
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// @param off Starting offset of this row.
// @param et_offset Starting offset of this range.
// @param cur_et The edge type.
// @param et_len Length of the range.
// @param et_idx A map from local idx to column id.
// @param et_eid Edge-type-specific id array.
// @param eid Pointer of the homogenized edge id array.
// @param out_idx Output buffer for the picks. As consumed by
//        CSRRowWisePerEtypePick, every entry is a position relative to
//        et_offset, i.e. a value in [0, et_len), and entries left at -1 are
//        skipped.
template <typename IdxType>
using EtypeRangePickFn = std::function<void(
    IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
    const std::vector<IdxType>& et_idx, const std::vector<IdxType>& et_eid,
    const IdxType* eid, IdxType* out_idx)>;
97

98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
// Template for picking non-zero values row-wise, returning the picks as a CSR
// block together with a COO-style row array.
//
// The work is done in two passes inside one OpenMP parallel region:
//   1. every thread asks num_picks_fn how many entries each of its rows
//      contributes and builds a thread-local prefix sum;
//   2. after the per-thread totals are turned into a global prefix sum and the
//      output arrays are allocated, every thread calls pick_fn for its rows and
//      writes directly into its own slice of the outputs.
//
// @tparam map_seed_nodes When true, seed_mapping[rows[i]] is set to i and
//         *new_seed_nodes receives a copy of rows.
// @param mat The CSR matrix to pick from.
// @param rows The row ids to pick from.
// @param seed_mapping Array written when map_seed_nodes is true.
// @param new_seed_nodes Output vector written when map_seed_nodes is true.
// @param num_picks Not read by this implementation; per-row counts come from
//        num_picks_fn.
// @param replace Not read by this implementation.
// @param pick_fn Function that picks entries of one row.
// @param num_picks_fn Function that returns the number of picks of one row.
// @return A pair of (CSR matrix with one row per entry of rows, array holding
//         for every picked entry the position i of its row within rows).
template <typename IdxType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWisePickFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdxType>* new_seed_nodes, int64_t num_picks, bool replace,
    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
  using namespace aten;

  const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
  const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
  const IdxType* data =
      CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;
  const IdxType* rows_data = static_cast<IdxType*>(rows->data);
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
  const auto& idtype = mat.indptr->dtype;
  IdxType* seed_mapping_data = nullptr;
  if (map_seed_nodes) seed_mapping_data = seed_mapping.Ptr<IdxType>();

  const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
  // global_prefix[t + 1] first receives the number of picks of thread t and is
  // turned into an inclusive prefix sum by the master thread below.
  std::vector<int64_t> global_prefix(num_threads + 1, 0);

  IdArray picked_col, picked_idx, picked_coo_rows;

  IdArray block_csr_indptr = IdArray::Empty({num_rows + 1}, idtype, ctx);
  IdxType* block_csr_indptr_data = block_csr_indptr.Ptr<IdxType>();

#pragma omp parallel num_threads(num_threads)
  {
    const int thread_id = omp_get_thread_num();

    // Contiguous chunk [start_i, end_i) of rows handled by this thread; the
    // first (num_rows % num_threads) threads take one extra row.
    const int64_t start_i =
        thread_id * (num_rows / num_threads) +
        std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
    const int64_t end_i =
        (thread_id + 1) * (num_rows / num_threads) +
        std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
    assert(thread_id + 1 < num_threads || end_i == num_rows);

    const int64_t num_local = end_i - start_i;

    // Pass 1: thread-local prefix sum over the per-row pick counts.
    std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);
    local_prefix[0] = 0;
    for (int64_t i = start_i; i < end_i; ++i) {
      const int64_t local_i = i - start_i;
      const IdxType rid = rows_data[i];
      if (map_seed_nodes) seed_mapping_data[rid] = i;

      const IdxType len = num_picks_fn(
          rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
      local_prefix[local_i + 1] = local_prefix[local_i] + len;
    }
    global_prefix[thread_id + 1] = local_prefix[num_local];

#pragma omp barrier
#pragma omp master
    {
      for (int t = 0; t < num_threads; ++t) {
        global_prefix[t + 1] += global_prefix[t];
      }
      const int64_t total_picks = global_prefix[num_threads];
      picked_col = IdArray::Empty({total_picks}, idtype, ctx);
      picked_idx = IdArray::Empty({total_picks}, idtype, ctx);
      picked_coo_rows = IdArray::Empty({total_picks}, idtype, ctx);
    }

#pragma omp barrier
    IdxType* picked_cdata = picked_col.Ptr<IdxType>();
    IdxType* picked_idata = picked_idx.Ptr<IdxType>();
    IdxType* picked_rows = picked_coo_rows.Ptr<IdxType>();

    // Kept in 64 bits so that the offset arithmetic below is not narrowed to
    // IdxType before it is used.
    const int64_t thread_offset = global_prefix[thread_id];

    // Pass 2: pick and write into this thread's slice of the outputs.
    for (int64_t i = start_i; i < end_i; ++i) {
      const IdxType rid = rows_data[i];
      const int64_t local_i = i - start_i;
      const int64_t row_offset = local_prefix[local_i] + thread_offset;
      // The indptr entry is written for every row, including empty ones.
      block_csr_indptr_data[i] = row_offset;

      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;
      if (len == 0) continue;

      // Named differently from the (unused) num_picks parameter on purpose.
      const int64_t row_num_picks =
          local_prefix[local_i + 1] - local_prefix[local_i];

      pick_fn(
          rid, off, len, row_num_picks, indices, data,
          picked_idata + row_offset);
      for (int64_t j = 0; j < row_num_picks; ++j) {
        // pick_fn stored positions into indices; translate them in place.
        const IdxType picked = picked_idata[row_offset + j];
        picked_cdata[row_offset + j] = indices[picked];
        picked_idata[row_offset + j] = data ? data[picked] : picked;
        picked_rows[row_offset + j] = i;
      }
    }
  }
  block_csr_indptr_data[num_rows] = global_prefix.back();

  const IdxType num_cols = picked_col->shape[0];
  if (map_seed_nodes) {
    new_seed_nodes->assign(rows_data, rows_data + num_rows);
  }

  return std::make_pair(
      CSRMatrix(num_rows, num_cols, block_csr_indptr, picked_col, picked_idx),
      picked_coo_rows);
}

207
// Template for picking non-zero values row-wise. The implementation utilizes
208
209
// OpenMP parallelization on rows because each row performs computation
// independently.
210
template <typename IdxType>
211
212
213
COOMatrix CSRRowWisePick(
    CSRMatrix mat, IdArray rows, int64_t num_picks, bool replace,
    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
214
215
216
  using namespace aten;
  const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
  const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
217
218
  const IdxType* data =
      CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;
219
220
221
  const IdxType* rows_data = static_cast<IdxType*>(rows->data);
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
222
  const auto& idtype = mat.indptr->dtype;
223
224
225
226
227
228
229
230
231

  // To leverage OMP parallelization, we create two arrays to store
  // picked src and dst indices. Each array is of length num_rows * num_picks.
  // For rows whose nnz < num_picks, the indices are padded with -1.
  //
  // We check whether all the given rows
  // have at least num_picks number of nnz when replace is false.
  //
  // If the check holds, remove -1 elements by remove_if operation, which simply
232
233
234
  // moves valid elements to the head of arrays and create a view of the
  // original array. The implementation consumes a little extra memory than the
  // actual requirement.
235
  //
236
237
  // Otherwise, directly use the row and col arrays to construct the result COO
  // matrix.
238
  //
239
240
  // [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism
  // is more
241
  //   significant. (minjie)
242

243
244
  // Do not use omp_get_max_threads() since that doesn't work for compiling
  // without OpenMP.
245
246
247
  const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
  std::vector<int64_t> global_prefix(num_threads + 1, 0);

248
249
250
251
  // TODO(BarclayII) Using OMP parallel directly instead of using
  // runtime::parallel_for does not handle exceptions well (directly aborts when
  // an exception pops up). It runs faster though because there is less
  // scheduling.  Need to handle exceptions better.
252
  IdArray picked_row, picked_col, picked_idx;
253
254
255
256
#pragma omp parallel num_threads(num_threads)
  {
    const int thread_id = omp_get_thread_num();

257
258
    const int64_t start_i =
        thread_id * (num_rows / num_threads) +
259
        std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
260
261
    const int64_t end_i =
        (thread_id + 1) * (num_rows / num_threads) +
262
263
264
265
        std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
    assert(thread_id + 1 < num_threads || end_i == num_rows);

    const int64_t num_local = end_i - start_i;
266

267
268
269
270
271
    // make sure we don't have to pay initialization cost
    std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);
    local_prefix[0] = 0;
    for (int64_t i = start_i; i < end_i; ++i) {
      // build prefix-sum
272
      const int64_t local_i = i - start_i;
273
      const IdxType rid = rows_data[i];
274
275
      IdxType len = num_picks_fn(
          rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
276
277
278
279
      local_prefix[local_i + 1] = local_prefix[local_i] + len;
    }
    global_prefix[thread_id + 1] = local_prefix[num_local];

280
281
#pragma omp barrier
#pragma omp master
282
283
    {
      for (int t = 0; t < num_threads; ++t) {
284
        global_prefix[t + 1] += global_prefix[t];
285
      }
286
287
288
      picked_row = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
      picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
      picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
289
    }
290

291
#pragma omp barrier
292
293
294
295
    IdxType* picked_rdata = picked_row.Ptr<IdxType>();
    IdxType* picked_cdata = picked_col.Ptr<IdxType>();
    IdxType* picked_idata = picked_idx.Ptr<IdxType>();

296
    const IdxType thread_offset = global_prefix[thread_id];
297

298
299
300
301
302
    for (int64_t i = start_i; i < end_i; ++i) {
      const IdxType rid = rows_data[i];

      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;
303
      if (len == 0) continue;
304
305
306

      const int64_t local_i = i - start_i;
      const int64_t row_offset = thread_offset + local_prefix[local_i];
307
308
      const int64_t num_picks =
          thread_offset + local_prefix[local_i + 1] - row_offset;
309

310
311
      pick_fn(
          rid, off, len, num_picks, indices, data, picked_idata + row_offset);
312
313
314
315
316
      for (int64_t j = 0; j < num_picks; ++j) {
        const IdxType picked = picked_idata[row_offset + j];
        picked_rdata[row_offset + j] = rid;
        picked_cdata[row_offset + j] = indices[picked];
        picked_idata[row_offset + j] = data ? data[picked] : picked;
317
318
      }
    }
319
320
  }

321
322
  const int64_t new_len = global_prefix.back();

323
  return COOMatrix(
324
      mat.num_rows, mat.num_cols,
325
326
327
      picked_row.CreateView({new_len}, picked_row->dtype),
      picked_col.CreateView({new_len}, picked_row->dtype),
      picked_idx.CreateView({new_len}, picked_row->dtype));
328
329
}

330
// Template for picking non-zero values row-wise. The implementation utilizes
331
332
// OpenMP parallelization on rows because each row performs computation
// independently.
333
template <typename IdxType, typename DType>
334
335
336
337
338
COOMatrix CSRRowWisePerEtypePick(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_picks, bool replace,
    bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,
    const std::vector<NDArray>& prob_or_mask) {
339
  using namespace aten;
340
341
  const IdxType* indptr = mat.indptr.Ptr<IdxType>();
  const IdxType* indices = mat.indices.Ptr<IdxType>();
342
  const IdxType* eid = CSRHasData(mat) ? mat.data.Ptr<IdxType>() : nullptr;
343
  const IdxType* rows_data = rows.Ptr<IdxType>();
344
345
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
346
  const int64_t num_etypes = num_picks.size();
347
  const bool has_probs = (prob_or_mask.size() > 0);
348
349
350
351
  std::vector<IdArray> picked_rows(rows->shape[0]);
  std::vector<IdArray> picked_cols(rows->shape[0]);
  std::vector<IdArray> picked_idxs(rows->shape[0]);

352
  // Check if the number of picks have the same value.
353
354
  // If so, we can potentially speed up if we have a node with total number of
  // neighbors less than the given number of picks with replace=False.
355
356
357
358
359
360
  bool same_num_pick = true;
  int64_t num_pick_value = num_picks[0];
  for (int64_t num_pick : num_picks) {
    if (num_pick_value != num_pick) {
      same_num_pick = false;
      break;
361
    }
362
  }
363

364
  runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {
365
    for (size_t i = b; i < e; ++i) {
366
367
368
369
370
371
372
373
374
375
376
      const IdxType rid = rows_data[i];
      CHECK_LT(rid, mat.num_rows);
      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;

      // do something here
      if (len == 0) {
        picked_rows[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
        picked_cols[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
        picked_idxs[i] = NewIdArray(0, ctx, sizeof(IdxType) * 8);
        continue;
377
      }
378
379
380
381
382
383
384
385

      // fast path
      if (same_num_pick && len <= num_pick_value && !replace) {
        IdArray rows = Full(rid, len, sizeof(IdxType) * 8, ctx);
        IdArray cols = Full(-1, len, sizeof(IdxType) * 8, ctx);
        IdArray idx = Full(-1, len, sizeof(IdxType) * 8, ctx);
        IdxType* cdata = cols.Ptr<IdxType>();
        IdxType* idata = idx.Ptr<IdxType>();
386
387

        int64_t k = 0;
388
        for (int64_t j = 0; j < len; ++j) {
389
390
          const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
          auto it = std::upper_bound(
391
392
              eid2etype_offset.begin(), eid2etype_offset.end(),
              homogenized_eid);
393
          const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
394
          const IdxType heterogenized_eid =
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
              homogenized_eid - eid2etype_offset[heterogenized_etype];

          if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
            // No probability array, select all
            cdata[k] = indices[off + j];
            idata[k] = homogenized_eid;
            ++k;
          } else {
            // Select the entries with non-zero probability
            const NDArray& p = prob_or_mask[heterogenized_etype];
            const DType* pdata = p.Ptr<DType>();
            if (pdata[heterogenized_eid] > 0) {
              cdata[k] = indices[off + j];
              idata[k] = homogenized_eid;
              ++k;
            }
          }
412
        }
413
414
415
416

        picked_rows[i] = rows.CreateView({k}, rows->dtype);
        picked_cols[i] = cols.CreateView({k}, cols->dtype);
        picked_idxs[i] = idx.CreateView({k}, idx->dtype);
417
418
419
420
421
422
423
424
      } else {
        // need to do per edge type sample
        std::vector<IdxType> rows;
        std::vector<IdxType> cols;
        std::vector<IdxType> idx;

        std::vector<IdxType> et(len);
        std::vector<IdxType> et_idx(len);
425
        std::vector<IdxType> et_eid(len);
426
427
        std::iota(et_idx.begin(), et_idx.end(), 0);
        for (int64_t j = 0; j < len; ++j) {
428
429
          const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
          auto it = std::upper_bound(
430
431
              eid2etype_offset.begin(), eid2etype_offset.end(),
              homogenized_eid);
432
          const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
433
          const IdxType heterogenized_eid =
434
435
436
              homogenized_eid - eid2etype_offset[heterogenized_etype];
          et[j] = heterogenized_etype;
          et_eid[j] = heterogenized_eid;
437
        }
438
439
440
441
442
443
444
        if (!rowwise_etype_sorted)  // the edge type is sorted, not need to sort
                                    // it
          std::sort(
              et_idx.begin(), et_idx.end(),
              [&et](IdxType i1, IdxType i2) { return et[i1] < et[i2]; });
        CHECK_LT(et[et_idx[len - 1]], num_etypes)
            << "etype values exceed the number of fanouts";
445
446
447
448
449

        IdxType cur_et = et[et_idx[0]];
        int64_t et_offset = 0;
        int64_t et_len = 1;
        for (int64_t j = 0; j < len; ++j) {
450
451
          CHECK((j + 1 == len) || (et[et_idx[j]] <= et[et_idx[j + 1]]))
              << "Edge type is not sorted. Please sort in advance or specify "
452
                 "'rowwise_etype_sorted' as false.";
453
          if ((j + 1 == len) || cur_et != et[et_idx[j + 1]]) {
454
455
456
            // 1 end of the current etype
            // 2 end of the row
            // random pick for current etype
457
458
            if ((num_picks[cur_et] == -1) ||
                (et_len <= num_picks[cur_et] && !replace)) {
459
460
              // fast path, select all
              for (int64_t k = 0; k < et_len; ++k) {
461
                const IdxType eid_offset = off + et_idx[et_offset + k];
462
463
                const IdxType homogenized_eid =
                    eid ? eid[eid_offset] : eid_offset;
464
                auto it = std::upper_bound(
465
466
467
468
469
                    eid2etype_offset.begin(), eid2etype_offset.end(),
                    homogenized_eid);
                const IdxType heterogenized_etype =
                    it - eid2etype_offset.begin() - 1;
                const IdxType heterogenized_eid =
470
471
                    homogenized_eid - eid2etype_offset[heterogenized_etype];

472
473
                if (!has_probs ||
                    IsNullArray(prob_or_mask[heterogenized_etype])) {
474
475
476
477
478
479
480
481
482
483
484
485
486
487
                  // No probability, select all
                  rows.push_back(rid);
                  cols.push_back(indices[eid_offset]);
                  idx.push_back(homogenized_eid);
                } else {
                  // Select the entries with non-zero probability
                  const NDArray& p = prob_or_mask[heterogenized_etype];
                  const DType* pdata = p.Ptr<DType>();
                  if (pdata[heterogenized_eid] > 0) {
                    rows.push_back(rid);
                    cols.push_back(indices[eid_offset]);
                    idx.push_back(homogenized_eid);
                  }
                }
488
489
              }
            } else {
490
491
              IdArray picked_idx =
                  Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
492
              IdxType* picked_idata = picked_idx.Ptr<IdxType>();
493
494

              // need call random pick
495
496
497
              pick_fn(
                  off, et_offset, cur_et, et_len, et_idx, et_eid, eid,
                  picked_idata);
498
499
              for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
                const IdxType picked = picked_idata[k];
500
                if (picked == -1) continue;
501
                rows.push_back(rid);
502
                cols.push_back(indices[off + et_idx[et_offset + picked]]);
503
                if (eid) {
504
                  idx.push_back(eid[off + et_idx[et_offset + picked]]);
505
                } else {
506
                  idx.push_back(off + et_idx[et_offset + picked]);
507
                }
508
              }
509
            }
510

511
            if (j + 1 == len) break;
512
            // next etype
513
514
            cur_et = et[et_idx[j + 1]];
            et_offset = j + 1;
515
            et_len = 1;
516
          } else {
517
            et_len++;
518
519
520
          }
        }

521
522
523
524
        picked_rows[i] = VecToIdArray(rows, sizeof(IdxType) * 8, ctx);
        picked_cols[i] = VecToIdArray(cols, sizeof(IdxType) * 8, ctx);
        picked_idxs[i] = VecToIdArray(idx, sizeof(IdxType) * 8, ctx);
      }  // end processing one row
525

526
527
528
529
      CHECK_EQ(picked_rows[i]->shape[0], picked_cols[i]->shape[0]);
      CHECK_EQ(picked_rows[i]->shape[0], picked_idxs[i]->shape[0]);
    }  // end processing all rows
  });
530
531
532
533

  IdArray picked_row = Concat(picked_rows);
  IdArray picked_col = Concat(picked_cols);
  IdArray picked_idx = Concat(picked_idxs);
534
535
  return COOMatrix(
      mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
536
537
}

538
539
540
// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts it to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
541
template <typename IdxType>
542
543
544
COOMatrix COORowWisePick(
    COOMatrix mat, IdArray rows, int64_t num_picks, bool replace,
    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
545
546
  using namespace aten;
  const auto& csr = COOToCSR(COOSliceRows(mat, rows));
547
548
  const IdArray new_rows =
      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
549
550
  const auto& picked = CSRRowWisePick<IdxType>(
      csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);
551
552
553
554
  return COOMatrix(
      mat.num_rows, mat.num_cols,
      IndexSelect(rows, picked.row),  // map the row index to the correct one
      picked.col, picked.data);
555
556
}

557
558
559
// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts it to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
560
561
562
563
564
565
template <typename IdxType, typename DType>
COOMatrix COORowWisePerEtypePick(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_picks, bool replace,
    EtypeRangePickFn<IdxType> pick_fn,
    const std::vector<NDArray>& prob_or_mask) {
566
567
  using namespace aten;
  const auto& csr = COOToCSR(COOSliceRows(mat, rows));
568
569
  const IdArray new_rows =
      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
570
  const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>(
571
572
573
574
575
576
      csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn,
      prob_or_mask);
  return COOMatrix(
      mat.num_rows, mat.num_cols,
      IndexSelect(rows, picked.row),  // map the row index to the correct one
      picked.col, picked.data);
577
578
}

579
580
581
582
583
}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CPU_ROWWISE_PICK_H_