labor_sampling.hip 31.8 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
2
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/*!
 *   Copyright (c) 2022, NVIDIA Corporation
 *   Copyright (c) 2022, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 *   All rights reserved.
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 *
 * @file array/cuda/labor_sampling.cu
 * @brief Layer-neighbor (LABOR) sampling on the GPU, see arXiv:2210.13339.
 */

#include <dgl/aten/coo.h>
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <thrust/binary_search.h>
#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <thrust/gather.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/shuffle.h>
#include <thrust/transform.h>
#include <thrust/zip_function.h>

#include <algorithm>
sangwzh's avatar
sangwzh committed
39
#include <hipcub/hipcub.hpp>  // NOLINT
40
41
42
43
44
#include <limits>
#include <numeric>
#include <type_traits>
#include <utility>

sangwzh's avatar
sangwzh committed
45
46
#include "atomic.cuh"
#include "utils.h"
47
#include "../../graph/transform/cuda/cuda_map_edges.cuh"
48
#include "../../random/continuous_seed.h"
49
#include "../../runtime/cuda/cuda_common.h"
sangwzh's avatar
sangwzh committed
50
51
#include "functor.cuh"
#include "spmm.cuh"
52
53
54
55
56

namespace dgl {
namespace aten {
namespace impl {

using dgl::random::continuous_seed;

// Total number of threads per block for the kernels launched in this file.
constexpr int BLOCK_SIZE = 128;
// Number of threads along threadIdx.x that cooperate on a single row in
// _CSRRowWiseLayerSampleDegreeKernel.
constexpr int CTA_SIZE = 128;
// Relative tolerance for the c_s fixed-point iteration and for the
// importance-sampling convergence check.
constexpr double eps = 0.0001;

namespace {

// Maps a sampled edge slot `idx` of the row-contiguous sub-CSR (offsets in
// `subindptr`) to the output COO triple (row id, neighbor id, edge id).
template <typename IdType>
struct TransformOp {
  const IdType* idx_coo;    // position in `rows` owning each sub-CSR slot
  const IdType* rows;       // seed row ids
  const IdType* indptr;     // start of each seed row in the original CSR
  const IdType* subindptr;  // start of each seed row in the sub-CSR
  const IdType* indices;    // neighbor ids (indexed by idx when is_pinned)
  const IdType* data_arr;   // optional edge ids; nullptr -> CSR position
  bool is_pinned;

  __host__ __device__ auto operator()(IdType idx) {
    const auto in_row = idx_coo[idx];
    const auto row = rows[in_row];
    // position of this edge inside the original CSR arrays
    const auto in_idx = indptr[in_row] + idx - subindptr[in_row];
    // when is_pinned, `indices` is laid out like the sub-CSR
    const auto u = indices[is_pinned ? idx : in_idx];
    const auto data = data_arr ? data_arr[in_idx] : in_idx;
    return thrust::make_tuple(row, u, data);
  }
};

// Like TransformOp, but additionally emits the row position `in_row` and the
// edge weight w / min(1, c_s * w2 * pi), where w = A[edge], w2 = B[edge] and
// pi = probs[idx]. probs/A/B may be constant iterators so that one functor
// serves the weighted and unweighted variants.
template <
    typename IdType, typename FloatType, typename probs_t, typename A_t,
    typename B_t>
struct TransformOpImp {
  probs_t probs;
  A_t A;
  B_t B;
  const IdType* idx_coo;    // position in `rows` owning each sub-CSR slot
  const IdType* rows;       // seed row ids
  const FloatType* cs;      // c_s per seed row
  const IdType* indptr;     // start of each seed row in the original CSR
  const IdType* subindptr;  // start of each seed row in the sub-CSR
  const IdType* indices;    // neighbor ids (indexed by idx when is_pinned)
  const IdType* data_arr;   // optional edge ids; nullptr -> CSR position
  bool is_pinned;

  __host__ __device__ auto operator()(IdType idx) {
    const auto ps = probs[idx];
    const auto in_row = idx_coo[idx];
    const auto c = cs[in_row];
    const auto row = rows[in_row];
    // position of this edge inside the original CSR arrays
    const auto in_idx = indptr[in_row] + idx - subindptr[in_row];
    const auto u = indices[is_pinned ? idx : in_idx];
    const auto w = A[in_idx];
    const auto w2 = B[in_idx];
    const auto data = data_arr ? data_arr[in_idx] : in_idx;
    return thrust::make_tuple(
        in_row, row, u, data, w / min((FloatType)1, c * w2 * ps));
  }
};

// Keep-predicate for importance sampling: an edge of row `in_row` with
// probability `ps` survives when its pre-rolled random number `rnd` does not
// exceed c_s * ps.
template <typename FloatType>
struct StencilOp {
  const FloatType* cs;
  template <typename IdType>
  __host__ __device__ auto operator()(
      IdType in_row, FloatType ps, FloatType rnd) {
    const FloatType acceptance = cs[in_row] * ps;
    return rnd <= acceptance;
  }
};

// Keep-predicate that rolls the random number on the fly: for edge slot `idx`
// it derives r_t from `seed` and the (optionally remapped) neighbor id t, and
// keeps the edge when r_t <= c_s * A[edge] * probs[idx].
template <typename IdType, typename FloatType, typename ps_t, typename A_t>
struct StencilOpFused {
  const continuous_seed seed;
  const IdType* idx_coo;    // position in the seed rows owning each slot
  const FloatType* cs;      // c_s per seed row
  const ps_t probs;         // per-slot probability factor
  const A_t A;              // per-edge weight, indexed by CSR position
  const IdType* subindptr;  // start of each seed row in the sub-CSR
  const IdType* indptr;     // start of each seed row in the original CSR
  const IdType* indices;    // neighbor ids (indexed by idx when is_pinned)
  const IdType* nids;       // optional neighbor id remapping
  bool is_pinned;

  __host__ __device__ auto operator()(IdType idx) {
    const auto in_row = idx_coo[idx];
    const auto ps = probs[idx];
    IdType rofs = idx - subindptr[in_row];
    const auto in_idx = indptr[in_row] + rofs;
    const auto u = indices[is_pinned ? idx : in_idx];
    const auto t = nids ? nids[u] : u;  // t in the paper
    // rolled random number r_t is a function of the random_seed and t
    const float rnd = seed.uniform(t);
    return rnd <= cs[in_row] * A[in_idx] * ps;
  }
};

// Rescales a picked edge weight by ds[idx] / ws[idx]. The caller fills ds and
// ws with per-row picked-edge counts and weight sums (zero for rows with no
// picked edge), so the rescaled weights of a row average to 1.
template <typename IdType, typename FloatType>
struct TransformOpMean {
  const IdType* ds;
  const FloatType* ws;
  __host__ __device__ auto operator()(IdType idx, FloatType ps) {
    const auto count = ds[idx];
    const auto weight_sum = ws[idx];
    return ps * count / weight_sum;
  }
};

// Clamps a value to at most one; summing min(1, pi_t) gives the expected
// number of sampled vertices checked for convergence.
struct TransformOpMinWith1 {
  template <typename FloatType>
  __host__ __device__ auto operator()(FloatType x) {
    const FloatType one = 1;
    return min(one, x);
  }
};

// Returns indptr[row], advanced by in_deg[row] when in_deg is given. With
// in_deg == nullptr it yields segment begin offsets, otherwise segment end
// offsets, for a per-row segmented reduction.
template <typename IdType>
struct IndptrFunc {
  const IdType* indptr;
  const IdType* in_deg;  // nullptr -> begin offsets, else end offsets
  __host__ __device__ auto operator()(IdType row) {
    return indptr[row] + (in_deg ? in_deg[row] : 0);
  }
};

// Maps x to the pair (x, x * x) so that a single segmented reduction produces
// both the sum and the sum of squares.
template <typename FloatType>
struct SquareFunc {
  __host__ __device__ auto operator()(FloatType x) {
    const FloatType squared = x * x;
    return thrust::make_tuple(x, squared);
  }
};

// Component-wise sum of the first two elements of two tuples.
struct TupleSum {
  template <typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    const auto first = thrust::get<0>(a) + thrust::get<0>(b);
    const auto second = thrust::get<1>(a) + thrust::get<1>(b);
    return thrust::make_tuple(first, second);
  }
};

// For seed position tIdx, looks up its row in the original CSR and records:
// in_deg[tIdx] = its degree d, inrow_indptr[tIdx] = indptr[row] (so indptr is
// read only once) and the initial c_s estimate cs[tIdx] = num_picks / d.
template <typename IdType, typename FloatType>
struct DegreeFunc {
  const IdType num_picks;
  const IdType* rows;
  const IdType* indptr;
  IdType* in_deg;
  IdType* inrow_indptr;
  FloatType* cs;
  __host__ __device__ auto operator()(IdType tIdx) {
    const auto out_row = rows[tIdx];
    const auto indptr_val = indptr[out_row];
    const auto d = indptr[out_row + 1] - indptr_val;
    in_deg[tIdx] = d;
    inrow_indptr[tIdx] = indptr_val;
    cs[tIdx] = num_picks / (FloatType)d;
  }
};

// For every slot tx of the row-contiguous sub-CSR, stores the neighbor id in
// hop[tx] (unless `indices` already is `hop`), the edge weight in A_l[tx] (when
// A is given) and the random number r_t rolled for that neighbor in rands[tx].
//
// Launch: 1-D grid of 1-D blocks; a grid-stride loop covers all hop_size slots
// for any grid size.
template <typename IdType, typename FloatType>
__global__ void _CSRRowWiseOneHopExtractorKernel(
    const continuous_seed seed, const IdType hop_size,
    const IdType* const indptr, const IdType* const subindptr,
    const IdType* const indices, const IdType* const idx_coo,
    const IdType* const nids, const FloatType* const A, FloatType* const rands,
    IdType* const hop, FloatType* const A_l) {
  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;

  while (tx < hop_size) {
    IdType rpos = idx_coo[tx];
    IdType rofs = tx - subindptr[rpos];
    const auto in_idx = indptr[rpos] + rofs;
    // indices == hop signals that hop already holds the neighbors in sub-CSR
    // order, so they are read by slot instead of by CSR position.
    const auto not_pinned = indices != hop;
    const auto u = indices[not_pinned ? in_idx : tx];
    if (not_pinned) hop[tx] = u;
    const auto t = nids ? nids[u] : u;
    if (A) A_l[tx] = A[in_idx];
    // rolled random number r_t is a function of the random_seed and t
    rands[tx] = (FloatType)seed.uniform(t);
    tx += stride_x;
  }
}

// Alignment granularity, in bytes, for reads of pinned `indices` in
// _CSRRowWiseOneHopExtractorAlignedKernel.
constexpr int CACHE_LINE_SIZE = 128;

// Padded length of the row at (possibly permuted) position `row`: its degree
// plus (elements per cache line - 1), so that the row can be shifted to start
// at any offset within a cache line. An exclusive scan over num_rows + 1 of
// these values gives the padded row offsets.
template <typename IdType>
struct AlignmentFunc {
  static_assert(CACHE_LINE_SIZE % sizeof(IdType) == 0);
  const IdType* in_deg;
  const int64_t* perm;  // optional position -> row mapping; nullptr = identity
  IdType num_rows;
  __host__ __device__ auto operator()(IdType row) {
    constexpr int num_elements = CACHE_LINE_SIZE / sizeof(IdType);
    // The trailing scan input (row == num_rows) never contributes to an
    // exclusive sum, so it carries no degree. This also avoids the previous
    // `row % num_rows`, which divided by zero when there were no rows.
    const IdType deg =
        row < num_rows ? in_deg[perm ? perm[row] : row] : IdType(0);
    return deg + num_elements - 1;
  }
};

// Copies each seed row's neighbor ids from `indices` (read at CSR position
// indptr[rpos] + rofs) into the compacted `hop` array at subindptr[rpos] + rofs.
//
// Threads walk a padded index space of size hop_size whose per-position start
// offsets are `subindptr_aligned`. Inside its padded range each row is shifted
// by `offset` elements so that the address of indices[in_idx] minus tx elements
// is CACHE_LINE_SIZE-aligned (checked by the assert below). Consecutive threads
// therefore read consecutive addresses starting on a cache-line boundary.
// Slots outside [0, d) are padding and do nothing.
//
// perm: optional mapping from padded position rpos_ to row rpos; nullptr means
// identity. NOTE(review): presumably the order obtained by sorting the seed
// rows in the caller — confirm.
// Launch: 1-D grid of 1-D blocks; a grid-stride loop covers tx < hop_size.
template <typename IdType>
__global__ void _CSRRowWiseOneHopExtractorAlignedKernel(
    const IdType hop_size, const IdType num_rows, const IdType* const indptr,
    const IdType* const subindptr, const IdType* const subindptr_aligned,
    const IdType* const indices, IdType* const hop, const int64_t* const perm) {
  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;

  while (tx < hop_size) {
    // padded position whose range [subindptr_aligned[rpos_], next) contains tx
    const IdType rpos_ =
        dgl::cuda::_UpperBound(subindptr_aligned, num_rows, tx) - 1;
    const IdType rpos = perm ? perm[rpos_] : rpos_;
    // destination offset of this row in `hop` and its degree
    const auto out_row = subindptr[rpos];
    const auto d = subindptr[rpos + 1] - out_row;
    // shift (in elements) that lines indices[indptr[rpos]] up with thread
    // subindptr_aligned[rpos_] + offset modulo the cache line
    const int offset =
        ((uint64_t)(indices + indptr[rpos] - subindptr_aligned[rpos_]) %
         CACHE_LINE_SIZE) /
        sizeof(IdType);
    const IdType rofs = tx - subindptr_aligned[rpos_] - offset;
    // skip the padding before and after the row's real entries
    if (rofs >= 0 && rofs < d) {
      const auto in_idx = indptr[rpos] + rofs;
      assert((uint64_t)(indices + in_idx - tx) % CACHE_LINE_SIZE == 0);
      const auto u = indices[in_idx];
      hop[out_row + rofs] = u;
    }
    tx += stride_x;
  }
}

// Refines c_s (arXiv:2210.13339) for every seed row by fixed-point iteration,
// c <- c * lhs / rhs, until the left-hand side of Equation (22) matches its
// right-hand side within relative tolerance `eps`.
//
// cs:        in/out, per-row c_s.
// ds, d2s:   optional per-row sums of A and of A^2 (weighted graphs only).
// indptr:    start of each seed row in the original CSR (indexes A).
// subindptr: start of each seed row in the sub-CSR (indexes probs).
// probs:     optional per-slot pi_t; when nullptr the weight itself is used.
// A:         optional per-edge weights; nullptr selects the unweighted formula.
//
// Launch: blockDim = (CTA_SIZE, BLOCK_CTAS), each block covers TILE_SIZE rows.
template <typename IdType, typename FloatType, int BLOCK_CTAS, int TILE_SIZE>
__global__ void _CSRRowWiseLayerSampleDegreeKernel(
    const IdType num_picks, const IdType num_rows, FloatType* const cs,
    const FloatType* const ds, const FloatType* const d2s,
    const IdType* const indptr, const FloatType* const probs,
    const FloatType* const A, const IdType* const subindptr) {
  // The block reduction is instantiated for a 1-D block of BLOCK_SIZE threads
  // and __syncthreads() is called inside the per-row loop, so every thread of
  // a block must be working on the same row.
  static_assert(
      BLOCK_CTAS == 1 && CTA_SIZE == BLOCK_SIZE,
      "_CSRRowWiseLayerSampleDegreeKernel processes one row per block");
  typedef hipcub::BlockReduce<FloatType, BLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ FloatType var_1_bcast[BLOCK_CTAS];

  // each threadIdx.y slice of CTA_SIZE threads processes one row at a time
  assert(blockDim.x == CTA_SIZE);
  assert(blockDim.y == BLOCK_CTAS);

  IdType out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
  const auto last_row =
      min(static_cast<IdType>(blockIdx.x + 1) * TILE_SIZE, num_rows);

  constexpr FloatType ONE = 1;

  while (out_row < last_row) {
    const auto in_row_start = indptr[out_row];
    const auto out_row_start = subindptr[out_row];
    const IdType degree = subindptr[out_row + 1] - out_row_start;

    if (degree > 0) {
      // stands for k in arXiv:2210.13339, i.e. fanout
      const auto k = min(num_picks, degree);
      // slightly better than NS
      const FloatType d_ = ds ? ds[out_row] : degree;
      // stands for the right-hand side of Equation (22) in arXiv:2210.13339
      FloatType var_target =
          d_ * d_ / k + (ds ? d2s[out_row] - d_ * d_ / degree : 0);

      auto c = cs[out_row];
      const int num_valid = min(degree, (IdType)CTA_SIZE);
      // stands for the left-hand side of Equation (22) in arXiv:2210.13339
      FloatType var_1;
      do {
        var_1 = 0;
        if (A) {
          for (int idx = threadIdx.x; idx < degree; idx += CTA_SIZE) {
            const auto w = A[in_row_start + idx];
            const auto ps = probs ? probs[out_row_start + idx] : w;
            var_1 += w > 0 ? w * w / min(ONE, c * ps) : 0;
          }
        } else {
          for (int idx = threadIdx.x; idx < degree; idx += CTA_SIZE) {
            const auto ps = probs[out_row_start + idx];
            var_1 += 1 / min(ONE, c * ps);
          }
        }
        var_1 = BlockReduce(temp_storage).Sum(var_1, num_valid);
        // only thread 0 holds the reduced value; broadcast it to the block
        if (threadIdx.x == 0) var_1_bcast[threadIdx.y] = var_1;
        __syncthreads();
        var_1 = var_1_bcast[threadIdx.y];

        c *= var_1 / var_target;
      } while (min(var_1, var_target) / max(var_1, var_target) < 1 - eps);

      if (threadIdx.x == 0) cs[out_row] = c;
    }

    out_row += BLOCK_CTAS;
  }
}

}  // namespace

/**
 * @brief Number of bits needed to represent every value in [0, size), i.e.
 * ceil(log2(size)); returns 0 for size <= 1. Used as the end bit of the radix
 * sorts in this file.
 */
template <typename IdType>
int log_size(const IdType size) {
  if (size <= 0) return 0;
  for (int i = 0; i < static_cast<int>(sizeof(IdType)) * 8; i++)
    if (((size - 1) >> i) == 0) return i;
  return sizeof(IdType) * 8;
}

/**
 * @brief Computes the importance-sampling probabilities pi_t of
 * arXiv:2210.13339 for every slot of the sampled one-hop neighborhood and
 * refines c_s (cs_arr) accordingly.
 *
 * Runs `importance_sampling` iterations (until convergence when negative) and
 * stops early once the expected number of sampled vertices, sum_t min(1, pi_t),
 * changes by less than a relative `eps`.
 *
 * Also fills hop_1 with the neighbor ids (unless `indices` already is hop_1)
 * and rands with the rolled random numbers r_t.
 */
template <typename IdType, typename FloatType, typename exec_policy_t>
void compute_importance_sampling_probabilities(
    CSRMatrix mat, const IdType hop_size, hipStream_t stream,
    const continuous_seed seed, const IdType num_rows, const IdType* indptr,
    const IdType* subindptr, const IdType* indices, IdArray idx_coo_arr,
    const IdType* nids,
    FloatArray cs_arr,  // holds the computed cs values, has size num_rows
    const bool weighted, const FloatType* A, const FloatType* ds,
    const FloatType* d2s, const IdType num_picks, DGLContext ctx,
    const runtime::CUDAWorkspaceAllocator& allocator,
    const exec_policy_t& exec_policy, const int importance_sampling,
    IdType* hop_1,  // holds the contiguous one-hop neighborhood, has size |E|
    FloatType* rands,  // holds the rolled random numbers r_t for each edge, has
                       // size |E|
    FloatType* probs_found) {  // holds the computed pi_t values for each edge,
                               // has size |E|
  auto device = runtime::DeviceAPI::Get(ctx);
  auto idx_coo = idx_coo_arr.Ptr<IdType>();
  auto cs = cs_arr.Ptr<FloatType>();
  // per-slot copy of the edge weights, only needed for weighted graphs
  FloatArray A_l_arr = weighted
                           ? NewFloatArray(hop_size, ctx, sizeof(FloatType) * 8)
                           : NullArray();
  auto A_l = A_l_arr.Ptr<FloatType>();
  // A is only meaningful for weighted graphs; kernels branch on it being null.
  const FloatType* const A_w = weighted ? A : nullptr;

  const int max_log_num_vertices = log_size(mat.num_cols);

  {  // extracts the onehop neighborhood cols to a contiguous range into hop_1
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((hop_size + BLOCK_SIZE - 1) / BLOCK_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseOneHopExtractorKernel<IdType, FloatType>), grid, block, 0,
        stream, seed, hop_size, indptr, subindptr, indices, idx_coo, nids, A_w,
        rands, hop_1, A_l);
  }
  int64_t hop_uniq_size = 0;
  IdArray hop_new_arr = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);
  auto hop_new = hop_new_arr.Ptr<IdType>();
  auto hop_unique = allocator.alloc_unique<IdType>(hop_size);
  // After this block, hop_unique holds the unique set of one-hop neighbors and
  // hop_new holds the relabeled hop_1 (idx_coo already holds the relabeled
  // destination), so that hop_unique[hop_new] == hop_1.
  {
    auto hop_2 = allocator.alloc_unique<IdType>(hop_size);
    auto hop_3 = allocator.alloc_unique<IdType>(hop_size);

    device->CopyDataFromTo(
        hop_1, 0, hop_2.get(), 0, sizeof(IdType) * hop_size, ctx, ctx,
        mat.indptr->dtype);

    hipcub::DoubleBuffer<IdType> hop_b(hop_2.get(), hop_3.get());

    {  // sort the neighbor ids (only the low max_log_num_vertices bits vary)
      std::size_t temp_storage_bytes = 0;
      CUDA_CALL(hipcub::DeviceRadixSort::SortKeys(
          nullptr, temp_storage_bytes, hop_b, hop_size, 0, max_log_num_vertices,
          stream));

      auto temp = allocator.alloc_unique<char>(temp_storage_bytes);

      CUDA_CALL(hipcub::DeviceRadixSort::SortKeys(
          temp.get(), temp_storage_bytes, hop_b, hop_size, 0,
          max_log_num_vertices, stream));
    }

    auto hop_counts = allocator.alloc_unique<IdType>(hop_size + 1);
    auto hop_unique_size = allocator.alloc_unique<int64_t>(1);

    {  // deduplicate the sorted ids
      std::size_t temp_storage_bytes = 0;
      CUDA_CALL(hipcub::DeviceRunLengthEncode::Encode(
          nullptr, temp_storage_bytes, hop_b.Current(), hop_unique.get(),
          hop_counts.get(), hop_unique_size.get(), hop_size, stream));

      auto temp = allocator.alloc_unique<char>(temp_storage_bytes);

      CUDA_CALL(hipcub::DeviceRunLengthEncode::Encode(
          temp.get(), temp_storage_bytes, hop_b.Current(), hop_unique.get(),
          hop_counts.get(), hop_unique_size.get(), hop_size, stream));

      device->CopyDataFromTo(
          hop_unique_size.get(), 0, &hop_uniq_size, 0, sizeof(hop_uniq_size),
          ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
    }

    // relabel every slot to the index of its neighbor in hop_unique
    thrust::lower_bound(
        exec_policy, hop_unique.get(), hop_unique.get() + hop_uniq_size, hop_1,
        hop_1 + hop_size, hop_new);
  }

  // @todo Consider creating a CSC because the SpMV will be done multiple times.
  COOMatrix rmat(
      num_rows, hop_uniq_size, idx_coo_arr, hop_new_arr, NullArray(), true,
      mat.sorted);

  BcastOff bcast_off;
  bcast_off.use_bcast = false;
  bcast_off.out_len = 1;
  bcast_off.lhs_len = 1;
  bcast_off.rhs_len = 1;

  FloatArray probs_arr =
      NewFloatArray(hop_uniq_size, ctx, sizeof(FloatType) * 8);
  auto probs_1 = probs_arr.Ptr<FloatType>();
  FloatArray probs_arr_2 =
      NewFloatArray(hop_uniq_size, ctx, sizeof(FloatType) * 8);
  auto probs = probs_arr_2.Ptr<FloatType>();
  auto arg_u = NewIdArray(hop_uniq_size, ctx, sizeof(IdType) * 8);
  auto arg_e = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);

  double prev_ex_nodes = hop_uniq_size;

  for (int iters = 0; iters < importance_sampling || importance_sampling < 0;
       iters++) {
    // Per unique neighbor, a Max-reduction over its incident edges of c_s
    // (times the edge weight on the first weighted iteration).
    if (weighted && iters == 0) {
      cuda::SpMMCoo<
          IdType, FloatType, cuda::binary::Mul<FloatType>,
          cuda::reduce::Max<IdType, FloatType, true>>(
          bcast_off, rmat, cs_arr, A_l_arr, probs_arr_2, arg_u, arg_e);
    } else {
      cuda::SpMMCoo<
          IdType, FloatType, cuda::binary::CopyLhs<FloatType>,
          cuda::reduce::Max<IdType, FloatType, true>>(
          bcast_off, rmat, cs_arr, NullArray(), iters ? probs_arr : probs_arr_2,
          arg_u, arg_e);
    }

    // accumulate the product with the previous iterations' probabilities
    if (iters) {
      thrust::transform(
          exec_policy, probs_1, probs_1 + hop_uniq_size, probs, probs,
          thrust::multiplies<FloatType>{});
    }

    // scatter per-neighbor probabilities back to every slot
    thrust::gather(
        exec_policy, hop_new, hop_new + hop_size, probs, probs_found);

    {  // refine c_s for the new probabilities
      constexpr int BLOCK_CTAS = BLOCK_SIZE / CTA_SIZE;
      // the number of rows each thread block will cover
      constexpr int TILE_SIZE = BLOCK_CTAS;
      const dim3 block(CTA_SIZE, BLOCK_CTAS);
      const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
      CUDA_KERNEL_CALL(
          (_CSRRowWiseLayerSampleDegreeKernel<
              IdType, FloatType, BLOCK_CTAS, TILE_SIZE>),
          grid, block, 0, stream, (IdType)num_picks, num_rows, cs,
          weighted ? ds : nullptr, weighted ? d2s : nullptr, indptr,
          probs_found, A_w, subindptr);
    }

    {  // stop once the expected number of sampled vertices has converged
      auto probs_min_1 =
          thrust::make_transform_iterator(probs, TransformOpMinWith1{});
      const double cur_ex_nodes = thrust::reduce(
          exec_policy, probs_min_1, probs_min_1 + hop_uniq_size, 0.0);
      if (cur_ex_nodes / prev_ex_nodes >= 1 - eps) break;
      prev_ex_nodes = cur_ex_nodes;
    }
  }
}

/////////////////////////////// CSR ///////////////////////////////

/**
 * @brief LABOR-samples the neighborhoods of `rows_arr` in `mat` on the GPU
 * (arXiv:2210.13339).
 *
 * @param mat CSR adjacency; `indices` may live in pinned host memory.
 * @param rows_arr Seed rows whose neighborhoods are sampled.
 * @param num_picks Fanout k.
 * @param prob_arr Optional per-edge weights A indexed by CSR edge position; a
 *        null array selects unweighted sampling.
 * @param importance_sampling Number of importance-sampling iterations; 0
 *        disables it and a negative value iterates until convergence.
 * @param random_seed_arr Optional seed; when null one is drawn from
 *        RandomEngine::ThreadLocal().
 * @param seed2_contribution Forwarded with random_seed_arr to continuous_seed.
 * @param NIDs Optional remapping of neighbor ids used when rolling r_t.
 * @return The picked edges as COO (row ids, neighbor ids, edge ids) and, when
 *         importance sampling or weights are used, per-edge weights rescaled
 *         to average 1 within each row; otherwise a null array.
 */
template <DGLDeviceType XPU, typename IdType, typename FloatType>
__host__ std::pair<COOMatrix, FloatArray> CSRLaborSampling(
    CSRMatrix mat, IdArray rows_arr, const int64_t num_picks,
    FloatArray prob_arr, const int importance_sampling, IdArray random_seed_arr,
    float seed2_contribution, IdArray NIDs) {
  const bool weighted = !IsNullArray(prob_arr);

  const auto& ctx = rows_arr->ctx;

  runtime::CUDAWorkspaceAllocator allocator(ctx);

  const auto stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const auto exec_policy = thrust::hip::par_nosync(allocator).on(stream);

  auto device = runtime::DeviceAPI::Get(ctx);

  const IdType num_rows = rows_arr->shape[0];

  // Raw pointers come from GetDevicePointer instead of NDArray::Ptr,
  // presumably so that pinned host arrays resolve to device-accessible
  // addresses (TODO confirm). Optional arrays are mapped to nullptr
  // explicitly: kernels test A and nids for nullptr to pick their code path.
  IdType* const rows = static_cast<IdType*>(GetDevicePointer(rows_arr));
  IdType* const nids = IsNullArray(NIDs)
                           ? nullptr
                           : static_cast<IdType*>(GetDevicePointer(NIDs));
  FloatType* const A =
      weighted ? static_cast<FloatType*>(GetDevicePointer(prob_arr)) : nullptr;

  IdType* const indptr_ = static_cast<IdType*>(GetDevicePointer(mat.indptr));
  IdType* const indices_ = static_cast<IdType*>(GetDevicePointer(mat.indices));
  IdType* const data = CSRHasData(mat)
                           ? static_cast<IdType*>(GetDevicePointer(mat.data))
                           : nullptr;

  // Read indptr only once in case it is pinned and access is slow.
  auto indptr = allocator.alloc_unique<IdType>(num_rows);

  // compute in-degrees
  auto in_deg = allocator.alloc_unique<IdType>(num_rows + 1);
  // cs stands for c_s in arXiv:2210.13339
  FloatArray cs_arr = NewFloatArray(num_rows, ctx, sizeof(FloatType) * 8);
  auto cs = cs_arr.Ptr<FloatType>();
  // ds stands for A_{*s} in arXiv:2210.13339
  FloatArray ds_arr = weighted
                          ? NewFloatArray(num_rows, ctx, sizeof(FloatType) * 8)
                          : NullArray();
  auto ds = ds_arr.Ptr<FloatType>();
  // d2s stands for (A^2)_{*s} in arXiv:2210.13339, ^2 is elementwise.
  FloatArray d2s_arr = weighted
                           ? NewFloatArray(num_rows, ctx, sizeof(FloatType) * 8)
                           : NullArray();
  auto d2s = d2s_arr.Ptr<FloatType>();

  thrust::counting_iterator<IdType> iota(0);
  thrust::for_each(
      exec_policy, iota, iota + num_rows,
      DegreeFunc<IdType, FloatType>{
          (IdType)num_picks, rows, indptr_, in_deg.get(), indptr.get(), cs});

  if (weighted) {
    // per-row sums of A and A^2 via one segmented reduction
    auto b_offsets = thrust::make_transform_iterator(
        iota, IndptrFunc<IdType>{indptr.get(), nullptr});
    auto e_offsets = thrust::make_transform_iterator(
        iota, IndptrFunc<IdType>{indptr.get(), in_deg.get()});

    auto A_A2 = thrust::make_transform_iterator(A, SquareFunc<FloatType>{});
    auto ds_d2s = thrust::make_zip_iterator(ds, d2s);

    size_t prefix_temp_size = 0;
    CUDA_CALL(hipcub::DeviceSegmentedReduce::Reduce(
        nullptr, prefix_temp_size, A_A2, ds_d2s, num_rows, b_offsets, e_offsets,
        TupleSum{}, thrust::make_tuple((FloatType)0, (FloatType)0), stream));
    auto temp = allocator.alloc_unique<char>(prefix_temp_size);
    CUDA_CALL(hipcub::DeviceSegmentedReduce::Reduce(
        temp.get(), prefix_temp_size, A_A2, ds_d2s, num_rows, b_offsets,
        e_offsets, TupleSum{}, thrust::make_tuple((FloatType)0, (FloatType)0),
        stream));
  }

  // fill subindptr
  IdArray subindptr_arr = NewIdArray(num_rows + 1, ctx, sizeof(IdType) * 8);
  auto subindptr = subindptr_arr.Ptr<IdType>();

  IdType hop_size;
  {
    size_t prefix_temp_size = 0;
    CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
        nullptr, prefix_temp_size, in_deg.get(), subindptr, num_rows + 1,
        stream));
    auto temp = allocator.alloc_unique<char>(prefix_temp_size);
    CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
        temp.get(), prefix_temp_size, in_deg.get(), subindptr, num_rows + 1,
        stream));

    device->CopyDataFromTo(
        subindptr, num_rows * sizeof(hop_size), &hop_size, 0, sizeof(hop_size),
        ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
  }
  IdArray hop_arr = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);
  CSRMatrix smat(
      num_rows, mat.num_cols, subindptr_arr, hop_arr, NullArray(), mat.sorted);
  // @todo Consider fusing CSRToCOO into StencilOpFused kernel
  auto smatcoo = CSRToCOO(smat, false);

  auto idx_coo_arr = smatcoo.row;
  auto idx_coo = idx_coo_arr.Ptr<IdType>();

  auto hop_1 = hop_arr.Ptr<IdType>();
  const bool is_pinned = mat.indices.IsPinned();
  if (is_pinned) {
    // Copy the neighborhoods out of pinned memory once, with cache-line
    // aligned reads, so later passes read them from hop_1 instead.
    const auto res = Sort(rows_arr, log_size(mat.num_rows));
    const int64_t* perm = static_cast<int64_t*>(res.second->data);

    IdType aligned_hop_size;
    auto subindptr_aligned = allocator.alloc_unique<IdType>(num_rows + 1);
    {
      auto modified_in_deg = thrust::make_transform_iterator(
          iota, AlignmentFunc<IdType>{in_deg.get(), perm, num_rows});
      size_t prefix_temp_size = 0;
      CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
          nullptr, prefix_temp_size, modified_in_deg, subindptr_aligned.get(),
          num_rows + 1, stream));
      auto temp = allocator.alloc_unique<char>(prefix_temp_size);
      CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
          temp.get(), prefix_temp_size, modified_in_deg,
          subindptr_aligned.get(), num_rows + 1, stream));

      device->CopyDataFromTo(
          subindptr_aligned.get(), num_rows * sizeof(aligned_hop_size),
          &aligned_hop_size, 0, sizeof(aligned_hop_size), ctx,
          DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
    }
    const dim3 block(BLOCK_SIZE);
    const dim3 grid((aligned_hop_size + BLOCK_SIZE - 1) / BLOCK_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseOneHopExtractorAlignedKernel<IdType>), grid, block, 0,
        stream, aligned_hop_size, num_rows, indptr.get(), subindptr,
        subindptr_aligned.get(), indices_, hop_1, perm);
  }
  const auto indices = is_pinned ? hop_1 : indices_;

  auto rands =
      allocator.alloc_unique<FloatType>(importance_sampling ? hop_size : 1);
  auto probs_found =
      allocator.alloc_unique<FloatType>(importance_sampling ? hop_size : 1);

  if (weighted) {
    // Recompute c for weighted graphs.
    constexpr int BLOCK_CTAS = BLOCK_SIZE / CTA_SIZE;
    // the number of rows each thread block will cover
    constexpr int TILE_SIZE = BLOCK_CTAS;
    const dim3 block(CTA_SIZE, BLOCK_CTAS);
    const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
    CUDA_KERNEL_CALL(
        (_CSRRowWiseLayerSampleDegreeKernel<
            IdType, FloatType, BLOCK_CTAS, TILE_SIZE>),
        grid, block, 0, stream, (IdType)num_picks, num_rows, cs, ds, d2s,
        indptr.get(), nullptr, A, subindptr);
  }

  const continuous_seed random_seed =
      IsNullArray(random_seed_arr)
          ? continuous_seed(RandomEngine::ThreadLocal()->RandInt(1000000000))
          : continuous_seed(random_seed_arr, seed2_contribution);

  if (importance_sampling)
    compute_importance_sampling_probabilities<
        IdType, FloatType, decltype(exec_policy)>(
        mat, hop_size, stream, random_seed, num_rows, indptr.get(), subindptr,
        indices, idx_coo_arr, nids, cs_arr, weighted, A, ds, d2s,
        (IdType)num_picks, ctx, allocator, exec_policy, importance_sampling,
        hop_1, rands.get(), probs_found.get());

  IdArray picked_row = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);
  IdArray picked_col = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);
  IdArray picked_idx = NewIdArray(hop_size, ctx, sizeof(IdType) * 8);
  FloatArray picked_imp =
      importance_sampling || weighted
          ? NewFloatArray(hop_size, ctx, sizeof(FloatType) * 8)
          : NullArray();

  IdType* const picked_row_data = picked_row.Ptr<IdType>();
  IdType* const picked_col_data = picked_col.Ptr<IdType>();
  IdType* const picked_idx_data = picked_idx.Ptr<IdType>();
  FloatType* const picked_imp_data = picked_imp.Ptr<FloatType>();

  auto picked_inrow = allocator.alloc_unique<IdType>(
      importance_sampling || weighted ? hop_size : 1);

  // Sample edges here
  IdType num_edges;
  {
    thrust::constant_iterator<FloatType> one(1);
    if (importance_sampling) {
      auto output = thrust::make_zip_iterator(
          picked_inrow.get(), picked_row_data, picked_col_data, picked_idx_data,
          picked_imp_data);
      if (weighted) {
        auto transformed_output = thrust::make_transform_output_iterator(
            output,
            TransformOpImp<
                IdType, FloatType, FloatType*, FloatType*, decltype(one)>{
                probs_found.get(), A, one, idx_coo, rows, cs, indptr.get(),
                subindptr, indices, data, is_pinned});
        auto stencil =
            thrust::make_zip_iterator(idx_coo, probs_found.get(), rands.get());
        num_edges =
            thrust::copy_if(
                exec_policy, iota, iota + hop_size, stencil, transformed_output,
                thrust::make_zip_function(StencilOp<FloatType>{cs})) -
            transformed_output;
      } else {
        auto transformed_output = thrust::make_transform_output_iterator(
            output,
            TransformOpImp<
                IdType, FloatType, FloatType*, decltype(one), decltype(one)>{
                probs_found.get(), one, one, idx_coo, rows, cs, indptr.get(),
                subindptr, indices, data, is_pinned});
        auto stencil =
            thrust::make_zip_iterator(idx_coo, probs_found.get(), rands.get());
        num_edges =
            thrust::copy_if(
                exec_policy, iota, iota + hop_size, stencil, transformed_output,
                thrust::make_zip_function(StencilOp<FloatType>{cs})) -
            transformed_output;
      }
    } else {
      if (weighted) {
        auto output = thrust::make_zip_iterator(
            picked_inrow.get(), picked_row_data, picked_col_data,
            picked_idx_data, picked_imp_data);
        auto transformed_output = thrust::make_transform_output_iterator(
            output,
            TransformOpImp<
                IdType, FloatType, decltype(one), FloatType*, FloatType*>{
                one, A, A, idx_coo, rows, cs, indptr.get(), subindptr, indices,
                data, is_pinned});
        const auto pred =
            StencilOpFused<IdType, FloatType, decltype(one), FloatType*>{
                random_seed, idx_coo,      cs,      one,  A,
                subindptr,   indptr.get(), indices, nids, is_pinned};
        num_edges = thrust::copy_if(
                        exec_policy, iota, iota + hop_size, iota,
                        transformed_output, pred) -
                    transformed_output;
      } else {
        auto output = thrust::make_zip_iterator(
            picked_row_data, picked_col_data, picked_idx_data);
        auto transformed_output = thrust::make_transform_output_iterator(
            output, TransformOp<IdType>{
                        idx_coo, rows, indptr.get(), subindptr, indices, data,
                        is_pinned});
        const auto pred =
            StencilOpFused<IdType, FloatType, decltype(one), decltype(one)>{
                random_seed, idx_coo,      cs,      one,  one,
                subindptr,   indptr.get(), indices, nids, is_pinned};
        num_edges = thrust::copy_if(
                        exec_policy, iota, iota + hop_size, iota,
                        transformed_output, pred) -
                    transformed_output;
      }
    }
  }

  // Normalize edge weights here
  if (importance_sampling || weighted) {
    thrust::constant_iterator<IdType> one(1);
    // number of picked edges per key in picked_inrow (compacted by key)
    auto deg = allocator.alloc_unique<IdType>(num_rows);
    // sum of picked edge weights per key (compacted by key)
    auto wsum = allocator.alloc_unique<FloatType>(num_rows);
    // deg scattered to every row, zero for rows without picked edges
    auto deg_full = allocator.alloc_unique<IdType>(num_rows);
    // wsum scattered to every row, zero for rows without picked edges
    auto wsum_full = allocator.alloc_unique<FloatType>(num_rows);
    auto output_ = thrust::make_zip_iterator(deg.get(), wsum.get());
    // row ids of the rows that have at least one picked edge
    auto keys = allocator.alloc_unique<IdType>(num_rows);
    auto input = thrust::make_zip_iterator(one, picked_imp_data);
    auto new_end = thrust::reduce_by_key(
        exec_policy, picked_inrow.get(), picked_inrow.get() + num_edges, input,
        keys.get(), output_, thrust::equal_to<IdType>{}, TupleSum{});
    {
      thrust::constant_iterator<IdType> zero_int(0);
      thrust::constant_iterator<FloatType> zero_float(0);
      auto input = thrust::make_zip_iterator(zero_int, zero_float);
      auto output = thrust::make_zip_iterator(deg_full.get(), wsum_full.get());
      thrust::copy(exec_policy, input, input + num_rows, output);
      {
        const auto num_rows_2 = new_end.first - keys.get();
        thrust::scatter(
            exec_policy, output_, output_ + num_rows_2, keys.get(), output);
      }
    }
    {
      auto input =
          thrust::make_zip_iterator(picked_inrow.get(), picked_imp_data);
      auto transformed_input = thrust::make_transform_iterator(
          input, thrust::make_zip_function(TransformOpMean<IdType, FloatType>{
                     deg_full.get(), wsum_full.get()}));
      thrust::copy(
          exec_policy, transformed_input, transformed_input + num_edges,
          picked_imp_data);
    }
  }

  picked_row = picked_row.CreateView({num_edges}, picked_row->dtype);
  picked_col = picked_col.CreateView({num_edges}, picked_col->dtype);
  picked_idx = picked_idx.CreateView({num_edges}, picked_idx->dtype);
  if (importance_sampling || weighted)
    picked_imp = picked_imp.CreateView({num_edges}, picked_imp->dtype);

  return std::make_pair(
      COOMatrix(mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx),
      picked_imp);
}

// Explicit instantiations for the supported id / weight type combinations.
template std::pair<COOMatrix, FloatArray>
CSRLaborSampling<kDGLCUDA, int32_t, float>(
    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);
template std::pair<COOMatrix, FloatArray>
CSRLaborSampling<kDGLCUDA, int64_t, float>(
    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);
template std::pair<COOMatrix, FloatArray>
CSRLaborSampling<kDGLCUDA, int32_t, double>(
    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);
template std::pair<COOMatrix, FloatArray>
CSRLaborSampling<kDGLCUDA, int64_t, double>(
    CSRMatrix, IdArray, int64_t, FloatArray, int, IdArray, float, IdArray);

}  // namespace impl
}  // namespace aten
}  // namespace dgl