rowwise_sampling_prob.cu 24.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
/*!
 *  Copyright (c) 2022 by Contributors
 * \file array/cuda/rowwise_sampling_prob.cu
 * \brief weighted rowwise sampling. The degree computing kernels and
 * host-side functions are partially borrowed from the uniform rowwise
 * sampling code rowwise_sampling.cu.
 * \author pengqirong (OPPO), dlasalle and Xin from Nvidia.
 */
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <curand_kernel.h>
#include <cmath>
#include <numeric>

#include "./dgl_cub.cuh"
#include "../../array/cuda/atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"

// require CUB 1.17 to use DeviceSegmentedSort.
// The message argument keeps this declaration valid before C++17, where the
// single-argument form of static_assert does not exist.
static_assert(CUB_VERSION >= 101700,
              "CUB >= 1.17 is required for cub::DeviceSegmentedSort");

using namespace dgl::aten::cuda;

namespace dgl {
namespace aten {
namespace impl {

namespace {

// Threads per block for the sampling kernels below (_CSRAResValueKernel,
// _CSRRowWiseSampleKernel, _CSRRowWiseSampleReplaceKernel). Those kernels
// stride over a row's entries in steps of BLOCK_SIZE, and the latter two
// assert blockDim.x == BLOCK_SIZE. The host code derives TILE_SIZE from it.
// The degree kernels use their own block size and do not depend on it.
constexpr int BLOCK_SIZE = 128;

/**
* @brief Per-row sizes for weighted sampling *without* replacement.
*
* Each thread handles one entry `i` of `in_rows` and writes:
*   - out_deg[i]  = min(num_picks, deg): number of entries row i contributes
*                   to the sampled output;
*   - temp_deg[i] = deg when deg > num_picks, otherwise 0: number of A-Res
*                   keys row i needs (rows with deg <= num_picks keep every
*                   edge and need no keys).
* The thread owning the last row also writes a trailing 0 at index `num_rows`
* of both arrays, so an exclusive prefix sum over `num_rows + 1` elements
* yields every row offset plus the grand total.
*
* Launch: 1D grid of 1D blocks with at least `num_rows` threads in total.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg Output row sizes of the sampled matrix (length num_rows + 1).
* @param temp_deg Output A-Res segment sizes (length num_rows + 1).
*/
template<typename IdType>
__global__ void _CSRRowWiseSampleDegreeKernel(
    const int64_t num_picks,
    const int64_t num_rows,
    const IdType * const in_rows,
    const IdType * const in_ptr,
    IdType * const out_deg,
    IdType * const temp_deg) {
  const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= num_rows) {
    return;
  }

  const int64_t row = in_rows[i];
  const IdType row_deg = in_ptr[row + 1] - in_ptr[row];
  const IdType picks = static_cast<IdType>(num_picks);
  const bool needs_ares = row_deg > picks;

  // Only rows with more candidates than picks take part in A-Res sorting.
  temp_deg[i] = needs_ares ? row_deg : 0;
  out_deg[i] = needs_ares ? picks : row_deg;

  if (i + 1 == num_rows) {
    // Sentinel so the exclusive scan also produces the total length.
    out_deg[num_rows] = 0;
    temp_deg[num_rows] = 0;
  }
}

/**
* @brief Per-row sizes for weighted sampling *with* replacement.
*
* Each thread handles one entry `i` of `in_rows` and writes:
*   - temp_deg[i] = deg: every neighbor needs one CDF slot;
*   - out_deg[i]  = num_picks if the row has any neighbor, otherwise 0.
* The thread owning the last row also writes a trailing 0 at index `num_rows`
* of both arrays so an exclusive prefix sum over `num_rows + 1` elements
* yields every row offset plus the grand total.
*
* Launch: 1D grid of 1D blocks with at least `num_rows` threads in total.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg Output row sizes of the sampled matrix (length num_rows + 1).
* @param temp_deg Output CDF segment sizes, i.e. input degrees (length num_rows + 1).
*/
template<typename IdType>
__global__ void _CSRRowWiseSampleDegreeReplaceKernel(
    const int64_t num_picks,
    const int64_t num_rows,
    const IdType * const in_rows,
    const IdType * const in_ptr,
    IdType * const out_deg,
    IdType * const temp_deg) {
  const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
  if (i >= num_rows) {
    return;
  }

  const int64_t row = in_rows[i];
  const IdType row_deg = in_ptr[row + 1] - in_ptr[row];

  // Every neighbor gets a CDF entry.
  temp_deg[i] = row_deg;
  // A non-empty row always yields exactly num_picks samples.
  if (row_deg == 0) {
    out_deg[i] = 0;
  } else {
    out_deg[i] = static_cast<IdType>(num_picks);
  }

  if (i + 1 == num_rows) {
    // Sentinel so the exclusive scan also produces the total length.
    out_deg[num_rows] = 0;
    temp_deg[num_rows] = 0;
  }
}

/**
* @brief Read one value through an optional index indirection.
*
* Computes `array[idx_data[offset + idx]]` when `idx_data` is non-null and
* `array[offset + idx]` otherwise.
*
* @tparam IdType The ID type used for indices.
* @tparam FloatType The float type used for array values.
* @param array The array to read from.
* @param idx_data Optional index mapping array (may be nullptr).
* @param idx Position relative to `offset`.
* @param offset Base position.
* @param out Destination of the selected value (output).
*/
template<typename IdType, typename FloatType>
__device__ void _DoubleSlice(
    const FloatType * const array,
    const IdType * const idx_data,
    const IdType idx,
    const IdType offset,
    FloatType* const out) {
  const IdType pos = offset + idx;
  *out = (idx_data != nullptr) ? array[idx_data[pos]] : array[pos];
}

/**
* @brief Compute A-Res keys for weighted rowwise sampling without replacement.
* Keys are generated only for rows with deg > num_picks; other rows keep all
* of their edges and have no segment in the A-Res arrays.
*
* A-Res (Efraimidis & Spirakis) gives candidate i with weight w_i the key
* u_i^(1/w_i), with u_i ~ U(0, 1], and keeps the num_picks largest keys. This
* kernel stores the order-equivalent log-domain key log(u_i) / w_i instead.
* log is monotone, so the ranking, and hence the sample, is the same, but the
* value cannot underflow. In single precision u^(1/w) collapses to 0 once w is
* small (e.g. 0.5^(1/1e-3) ~ 1e-301). Many keys then become ties, and the sort
* breaks those ties by position instead of by weight.
*
* Entries with a non-positive (or NaN) weight get the key -inf so they rank
* below every positive-weight entry. (Previously a negative weight produced a
* key > 1 and was ranked first.)
* NOTE(review): such entries can still be selected when a row has fewer than
* num_picks positive-weight neighbors, since the per-row output size is fixed
* at num_picks.
*
* Layout: each thread block (BLOCK_SIZE threads) handles TILE_SIZE consecutive
* rows, one at a time; its threads stride over the entries of the current row.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param data The data array of the input CSR.
* @param prob The probability array of the input CSR.
* @param ares_ptr The offset to write each row to in the A-res array.
* @param ares_idxs The A-Res value corresponding index array, the index of input CSR (output).
* @param ares The A-Res key array (output); larger means more likely picked.
* @author pengqirong (OPPO)
*/
template<typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRAResValueKernel(
    const uint64_t rand_seed,
    const int64_t num_picks,
    const int64_t num_rows,
    const IdType * const in_rows,
    const IdType * const in_ptr,
    const IdType * const data,
    const FloatType * const prob,
    const IdType * const ares_ptr,
    IdType * const ares_idxs,
    FloatType * const ares) {

  int64_t out_row = blockIdx.x * TILE_SIZE;
  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

  while (out_row < last_row) {
    const int64_t row = in_rows[out_row];
    const int64_t in_row_start = in_ptr[row];
    const int64_t deg = in_ptr[row + 1] - in_row_start;
    // A-Res value needs to be calculated only if deg is greater than num_picks
    // in weighted rowwise sampling without replacement
    if (deg > num_picks) {
      const int64_t ares_row_start = ares_ptr[out_row];

      for (int64_t idx = threadIdx.x; idx < deg; idx += BLOCK_SIZE) {
        const int64_t in_idx = in_row_start + idx;
        const int64_t ares_idx = ares_row_start + idx;
        FloatType item_prob;
        _DoubleSlice<IdType, FloatType>(prob, data, idx, in_row_start, &item_prob);
        // Draw unconditionally so the per-thread RNG sequence does not depend
        // on the weights. curand_uniform returns a value in (0, 1], so
        // logf(u) is finite and <= 0.
        const float u = curand_uniform(&rng);
        // log-domain A-Res key: monotone in u^(1/w), immune to underflow.
        ares[ares_idx] = item_prob > static_cast<FloatType>(0)
            ? static_cast<FloatType>(logf(u)) / item_prob
            : -static_cast<FloatType>(INFINITY);
        ares_idxs[ares_idx] = static_cast<IdType>(in_idx);
      }
    }
    out_row += 1;
  }
}


/**
* @brief Emit the COO entries for weighted sampling without replacement.
*
* For a row with deg > num_picks, the output is the first num_picks entries of
* its segment in `sort_ares_idxs`. Those entries are the input-CSR positions
* of the largest A-Res keys once the segment is sorted in descending order.
* A row with deg <= num_picks copies all of its edges in their original order.
*
* Layout: each thread block handles TILE_SIZE consecutive rows, one at a time.
* Its BLOCK_SIZE threads stride over the entries of the current row.
* blockDim.x must equal BLOCK_SIZE.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices (unused here).
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_cols The columns array of the input CSR.
* @param data The data array of the input CSR (may be nullptr).
* @param out_ptr The offset to write each row to in the output COO.
* @param ares_ptr The offset of each row's segment in the A-Res arrays.
* @param sort_ares_idxs Input-CSR positions ordered by descending A-Res key.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
* @author pengqirong (OPPO)
*/
template<typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleKernel(
    const int64_t num_picks,
    const int64_t num_rows,
    const IdType * const in_rows,
    const IdType * const in_ptr,
    const IdType * const in_cols,
    const IdType * const data,
    const IdType * const out_ptr,
    const IdType * const ares_ptr,
    const IdType * const sort_ares_idxs,
    IdType * const out_rows,
    IdType * const out_cols,
    IdType * const out_idxs) {
  assert(blockDim.x == BLOCK_SIZE);

  const int64_t end_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  for (int64_t r = blockIdx.x * TILE_SIZE; r < end_row; ++r) {
    const int64_t row = in_rows[r];
    const int64_t row_begin = in_ptr[row];
    const int64_t dst_begin = out_ptr[r];
    const int64_t row_deg = in_ptr[row + 1] - row_begin;

    // Oversized rows read their picks from the sorted A-Res segment; the rest
    // copy every edge as is.
    const bool from_sorted = row_deg > num_picks;
    const int64_t count = from_sorted ? num_picks : row_deg;
    const int64_t ares_begin = from_sorted ? static_cast<int64_t>(ares_ptr[r]) : 0;

    for (int64_t k = threadIdx.x; k < count; k += BLOCK_SIZE) {
      const int64_t src = from_sorted
          ? static_cast<int64_t>(sort_ares_idxs[ares_begin + k])
          : row_begin + k;
      const int64_t dst = dst_begin + k;
      out_rows[dst] = static_cast<IdType>(row);
      out_cols[dst] = in_cols[src];
      out_idxs[dst] = static_cast<IdType>(data ? data[src] : src);
    }
  }
}


// Stateful prefix functor for cub::BlockScan: carries a running total across
// consecutive block-wide scans, so that scanning a long row tile by tile
// produces one continuous prefix sum.
template<typename FloatType>
struct BlockPrefixCallbackOp {
  // Sum of every block aggregate observed so far.
  FloatType running_total;

  __device__ BlockPrefixCallbackOp(FloatType initial_total)
      : running_total(initial_total) {}

  // Invoked by BlockScan (from its first warp) with the current tile's
  // aggregate. Returns the prefix that seeds this tile, then folds the tile
  // into the running total.
  __device__ FloatType operator()(FloatType block_aggregate) {
    const FloatType seed = running_total;
    running_total = seed + block_aggregate;
    return seed;
  }
};

/**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix,
* with replacement. We store the CDF (unnormalized) of all neighbors of a row
* in global memory and use binary search to find inverse indices as selected items.
*
* Negative weights are clamped to 0 before the CDF is built. Each pick draws
* rand in (0, total], and the chosen item is the first position whose
* cumulative weight is >= rand. Because rand > 0 and rand <= total, the search
* always lands inside the row and only on a positive-weight item.
* NOTE(review): a row whose weights are all <= 0 has total == 0. It still
* yields num_picks samples, all at its first position.
*
* Layout: each thread block (BLOCK_SIZE threads, blockDim.x must equal
* BLOCK_SIZE) handles TILE_SIZE consecutive rows, one at a time. Uses static
* shared memory for cub::BlockScan's TempStorage.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_cols The columns array of the input CSR.
* @param data The data array of the input CSR.
* @param prob The probability array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO.
* @param cdf_ptr The offset of each cdf segment.
* @param cdf The global buffer to store cdf segments.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
* @author pengqirong (OPPO)
*/
template<typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleReplaceKernel(
    const uint64_t rand_seed,
    const int64_t num_picks,
    const int64_t num_rows,
    const IdType * const in_rows,
    const IdType * const in_ptr,
    const IdType * const in_cols,
    const IdType * const data,
    const FloatType * const prob,
    const IdType * const out_ptr,
    const IdType * const cdf_ptr,
    FloatType * const cdf,
    IdType * const out_rows,
    IdType * const out_cols,
    IdType * const out_idxs
) {
  // The whole block (BLOCK_SIZE threads) cooperates on one row at a time.
  assert(blockDim.x == BLOCK_SIZE);

  int64_t out_row = blockIdx.x * TILE_SIZE;
  const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);

  while (out_row < last_row) {
    const int64_t row = in_rows[out_row];
    const int64_t in_row_start = in_ptr[row];
    const int64_t out_row_start = out_ptr[out_row];
    const int64_t cdf_row_start = cdf_ptr[out_row];
    const int64_t deg = in_ptr[row + 1] - in_row_start;
    const FloatType MIN_THREAD_DATA = static_cast<FloatType>(0.0f);

    // deg is the same for every thread of the block, so the __syncthreads()
    // calls below are reached uniformly.
    if (deg > 0) {
      // Specialize BlockScan for a 1D block of BLOCK_SIZE threads
      typedef cub::BlockScan<FloatType, BLOCK_SIZE> BlockScan;
      // Allocate shared memory for BlockScan
      __shared__ typename BlockScan::TempStorage temp_storage;
      // Initialize running total
      BlockPrefixCallbackOp<FloatType> prefix_op(MIN_THREAD_DATA);

      int64_t max_iter = (1 + (deg - 1) / BLOCK_SIZE) * BLOCK_SIZE;
      // Have the block iterate over segments of items
      for (int64_t idx = threadIdx.x; idx < max_iter; idx += BLOCK_SIZE) {
        // Load a segment of consecutive items that are blocked across threads
        FloatType thread_data;
        if (idx < deg)
          _DoubleSlice<IdType, FloatType>(prob, data, idx, in_row_start, &thread_data);
        else
          thread_data = MIN_THREAD_DATA;
        // Clamp negative weights to 0.
        thread_data = max(thread_data, MIN_THREAD_DATA);
        // Collectively compute the block-wide inclusive prefix sum
        BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, prefix_op);
        // temp_storage is reused by the next iteration.
        __syncthreads();

        // Store scanned items to cdf array
        if (idx < deg) {
          cdf[cdf_row_start + idx] = thread_data;
        }
      }
      // Make every thread's cdf writes visible before the binary searches.
      __syncthreads();

      // Total (unnormalized) weight of the row; identical for every pick.
      const FloatType sum = cdf[cdf_row_start + deg - 1];
      for (int64_t idx = threadIdx.x; idx < num_picks; idx += BLOCK_SIZE) {
        // curand_uniform returns a value in (0, 1], so rand lies in (0, sum].
        const FloatType rand = static_cast<FloatType>(curand_uniform(&rng) * sum);
        // Inverse CDF: the first offset whose cumulative weight reaches rand.
        // UpperBound (first cdf > rand) overshoots to `deg` when rand == sum.
        // That case was then clamped to the last item even when its weight
        // is zero. LowerBound never selects a zero-weight item for rand > 0.
        int64_t item = cub::LowerBound<FloatType*, int64_t, FloatType>(
            &cdf[cdf_row_start], deg, rand);
        // Defensive guard against floating-point edge cases.
        item = min(item, deg - 1);
        // get in and out index
        const int64_t in_idx = in_row_start + item;
        const int64_t out_idx = out_row_start + idx;
        // copy permutation over
        out_rows[out_idx] = static_cast<IdType>(row);
        out_cols[out_idx] = in_cols[in_idx];
        out_idxs[out_idx] = static_cast<IdType>(data ? data[in_idx] : in_idx);
      }
    }
    out_row += 1;
  }
}

}  // namespace

/////////////////////////////// CSR ///////////////////////////////

/**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix.
* Use CDF sampling algorithm for with replacement:
*   1) Calculate the CDF of all neighbor's prob.
*   2) For each [0, num_picks), generate a rand ~ U(0, 1).
*      Use binary search to find its index in the CDF array as a chosen item.
* Use A-Res sampling algorithm for without replacement:
*   1) For rows with deg > num_picks, calculate A-Res values for all neighbors.
*   2) Sort the A-Res array and select top-num_picks as chosen items.
*
* Output rows follow the order of `rows`. With replacement, each row with a
* nonzero degree contributes num_picks entries. Without replacement, each row
* contributes min(num_picks, deg) entries. The returned arrays are views
* truncated to the total number of sampled entries.
*
* The host blocks twice: once to read the size of the temporary CDF/A-Res
* buffer, and once at the end to read the output length.
*
* @tparam XPU The device type used for matrices.
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @param mat The CSR matrix.
* @param rows The set of rows to pick.
* @param num_picks The number of non-zeros to pick per row.
* @param prob The probability array of the input CSR.
* @param replace Is replacement sampling?
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat,
                             IdArray rows,
                             int64_t num_picks,
                             FloatArray prob,
                             bool replace) {
  const auto& ctx = rows->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);

  // TODO(dlasalle): Once the device api supports getting the stream from the
  // context, that should be used instead of the default stream here.
  cudaStream_t stream = 0;

  const int64_t num_rows = rows->shape[0];
  const IdType * const slice_rows = static_cast<const IdType*>(rows->data);

  // Upper-bound allocations; truncated to the real length at the end.
  IdArray picked_row = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_col = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  IdArray picked_idx = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
  const IdType * const in_ptr = static_cast<const IdType*>(mat.indptr->data);
  const IdType * const in_cols = static_cast<const IdType*>(mat.indices->data);
  IdType* const out_rows = static_cast<IdType*>(picked_row->data);
  IdType* const out_cols = static_cast<IdType*>(picked_col->data);
  IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);

  const IdType* const data = CSRHasData(mat) ?
      static_cast<IdType*>(mat.data->data) : nullptr;
  const FloatType* const prob_data = static_cast<const FloatType*>(prob->data);

  // compute degree
  // out_deg: the size of each row in the sampled matrix
  // temp_deg: the size of each row we will manipulate in sampling
  //    1) for w/o replacement: in degree if it's greater than num_picks else 0
  //    2) for w/ replacement: in degree
  IdType * out_deg = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  IdType * temp_deg = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
  if (replace) {
    const dim3 block(512);
    const dim3 grid((num_rows + block.x - 1) / block.x);
    CUDA_KERNEL_CALL(
        _CSRRowWiseSampleDegreeReplaceKernel,
        grid, block, 0, stream,
        num_picks, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
  } else {
    const dim3 block(512);
    const dim3 grid((num_rows + block.x - 1) / block.x);
    CUDA_KERNEL_CALL(
        _CSRRowWiseSampleDegreeKernel,
        grid, block, 0, stream,
        num_picks, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
  }

  // fill temp_ptr
  IdType * temp_ptr = static_cast<IdType*>(
    device->AllocWorkspace(ctx, (num_rows + 1)*sizeof(IdType)));
  size_t prefix_temp_size = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size,
      temp_deg,
      temp_ptr,
      num_rows + 1,
      stream));
  void * prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_temp, prefix_temp_size,
      temp_deg,
      temp_ptr,
      num_rows + 1,
      stream));
  device->FreeWorkspace(ctx, prefix_temp);
  device->FreeWorkspace(ctx, temp_deg);

  // TODO(Xin): The copy here is too small, and the overhead of creating
  // cuda events cannot be ignored. Just use synchronized copy.
  IdType temp_len;
  device->CopyDataFromTo(temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0,
      sizeof(temp_len),
      ctx,
      DGLContext{kDLCPU, 0},
      mat.indptr->dtype,
      stream);
  device->StreamSync(ctx, stream);

  // fill out_ptr
  IdType * out_ptr = static_cast<IdType*>(
      device->AllocWorkspace(ctx, (num_rows+1)*sizeof(IdType)));
  prefix_temp_size = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size,
      out_deg,
      out_ptr,
      num_rows+1,
      stream));
  prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_temp, prefix_temp_size,
      out_deg,
      out_ptr,
      num_rows+1,
      stream));
  device->FreeWorkspace(ctx, prefix_temp);
  device->FreeWorkspace(ctx, out_deg);

  cudaEvent_t copyEvent;
  CUDA_CALL(cudaEventCreate(&copyEvent));
  // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
  // a cudaevent
  IdType new_len;
  device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
      sizeof(new_len),
      ctx,
      DGLContext{kDLCPU, 0},
      mat.indptr->dtype,
      stream);
  CUDA_CALL(cudaEventRecord(copyEvent, stream));

  // allocate workspace
  // 1) for w/ replacement, it's a global buffer to store cdf segments (one segment for each row).
  // 2) for w/o replacement, it's used to store a-res segments (one segment for
  //    each row with degree > num_picks)
  FloatType * temp = static_cast<FloatType*>(
      device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));

  const uint64_t rand_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);

  // select edges
  // the number of rows each thread block will cover
  constexpr int TILE_SIZE = 128 / BLOCK_SIZE;
  if (replace) {  // with replacement.
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseSampleReplaceKernel<IdType, FloatType, TILE_SIZE>),
        grid, block, 0, stream,
        rand_seed,
        num_picks,
        num_rows,
        slice_rows,
        in_ptr,
        in_cols,
        data,
        prob_data,
        out_ptr,
        temp_ptr,
        temp,
        out_rows,
        out_cols,
        out_idxs);
    device->FreeWorkspace(ctx, temp);
  } else {  // without replacement
    IdType* temp_idxs = static_cast<IdType*>(
        device->AllocWorkspace(ctx, (temp_len) * sizeof(IdType)));

    // Compute A-Res value. A-Res value needs to be calculated only if deg
    // is greater than num_picks in weighted rowwise sampling without replacement.
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRAResValueKernel<IdType, FloatType, TILE_SIZE>),
        grid, block, 0, stream,
        rand_seed,
        num_picks,
        num_rows,
        slice_rows,
        in_ptr,
        data,
        prob_data,
        temp_ptr,
        temp_idxs,
        temp);

    // Sort each row's A-Res segment by descending key, carrying the edge
    // positions along. With DoubleBuffer, CUB may leave the result in either
    // buffer; Current() names the one that holds it.
    FloatType* sort_temp = static_cast<FloatType*>(
      device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));
    IdType* sort_temp_idxs = static_cast<IdType*>(
      device->AllocWorkspace(ctx, temp_len * sizeof(IdType)));

    cub::DoubleBuffer<FloatType> sort_keys(temp, sort_temp);
    cub::DoubleBuffer<IdType> sort_values(temp_idxs, sort_temp_idxs);

    void *d_temp_storage = nullptr;
    size_t temp_storage_bytes = 0;
    CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(
        d_temp_storage,
        temp_storage_bytes,
        sort_keys,
        sort_values,
        temp_len,
        num_rows,
        temp_ptr,
        temp_ptr + 1,
        stream));
    d_temp_storage = device->AllocWorkspace(ctx, temp_storage_bytes);
    CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(
        d_temp_storage,
        temp_storage_bytes,
        sort_keys,
        sort_values,
        temp_len,
        num_rows,
        temp_ptr,
        temp_ptr + 1,
        stream));
    device->FreeWorkspace(ctx, d_temp_storage);
    // Only the ordering is needed from here on; the keys can go.
    device->FreeWorkspace(ctx, temp);
    device->FreeWorkspace(ctx, sort_temp);

    // select top-num_picks as results
    CUDA_KERNEL_CALL(
        (_CSRRowWiseSampleKernel<IdType, FloatType, TILE_SIZE>),
        grid, block, 0, stream,
        num_picks,
        num_rows,
        slice_rows,
        in_ptr,
        in_cols,
        data,
        out_ptr,
        temp_ptr,
        sort_values.Current(),
        out_rows,
        out_cols,
        out_idxs);

    // sort_values.Current() aliases one of these two buffers, so they are
    // released only after the kernel that reads it has been launched.
    device->FreeWorkspace(ctx, temp_idxs);
    device->FreeWorkspace(ctx, sort_temp_idxs);
  }

  device->FreeWorkspace(ctx, temp_ptr);
  device->FreeWorkspace(ctx, out_ptr);

  // wait for copying `new_len` to finish
  CUDA_CALL(cudaEventSynchronize(copyEvent));
  CUDA_CALL(cudaEventDestroy(copyEvent));

  picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
  picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
  picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);

  return COOMatrix(mat.num_rows, mat.num_cols, picked_row,
      picked_col, picked_idx);
}

// Explicit GPU instantiations: 32- and 64-bit index types, each combined with
// float and double probability arrays.
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, float>(
  CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, float>(
  CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, double>(
  CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, double>(
  CSRMatrix, IdArray, int64_t, FloatArray, bool);

}  // namespace impl
}  // namespace aten
}  // namespace dgl