/**
 *  Copyright (c) 2021-2022 by Contributors
 * @file graph/sampling/randomwalk_gpu.cu
 * @brief CUDA random walk sampling
 */

#include <curand_kernel.h>

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>

#include <cub/cub.cuh>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "../../../runtime/cuda/cuda_common.h"
#include "frequency_hashmap.cuh"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

namespace {

/**
 * @brief Per-edge-type CSR view handed to the random walk kernels.
 *
 * Filled on the host and copied byte-for-byte into a device buffer, so every
 * member is a raw pointer into CSR arrays owned by the graph.
 */
template <typename IdType>
struct GraphKernelData {
  // Row offsets (CSR indptr): row v spans [in_ptr[v], in_ptr[v + 1]).
  const IdType* in_ptr;
  // Neighbor IDs (CSR indices), addressed by edge position.
  const IdType* in_cols;
  // Edge IDs (CSR data) by edge position; nullptr when the CSR carries no
  // data array, in which case the kernels use the position itself as the ID.
  const IdType* data;
};

/**
 * @brief Uniform random walk kernel.
 *
 * Launch layout: 1-D blocks of exactly BLOCK_SIZE threads (asserted). Block b
 * owns seeds [b * TILE_SIZE, min((b + 1) * TILE_SIZE, num_seeds)); each thread
 * starts at b * TILE_SIZE + threadIdx.x and advances by BLOCK_SIZE inside the
 * tile. Each thread seeds its own curand state with rand_seed plus its first
 * seed index.
 *
 * For seed idx, row idx of out_traces_data (max_num_steps + 1 entries) gets
 * the seed followed by the visited nodes, and row idx of out_eids_data
 * (max_num_steps entries) gets the traversed edge IDs. Step s follows edge
 * type metapath_data[s]; the neighbor index is curand() % deg. The walk stops
 * when the current node has no neighbors for that edge type, or when a
 * restart is drawn, in which case the step just taken is overwritten. Every
 * position from the stopping step on is filled with -1.
 *
 * Restarts: restart_prob_size == 0 disables them, == 1 applies
 * restart_prob_data[0] at every step, otherwise restart_prob_data[s] is used
 * for step s (presumably one entry per step — confirm at the caller).
 */
template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkKernel(
    const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
    const IdType *metapath_data, const uint64_t max_num_steps,
    const GraphKernelData<IdType> *graphs, const FloatType *restart_prob_data,
    const int64_t restart_prob_size, const int64_t max_nodes,
    IdType *out_traces_data, IdType *out_eids_data) {
  assert(BLOCK_SIZE == blockDim.x);
  // Widen before multiplying: blockIdx.x is unsigned int, so
  // blockIdx.x * TILE_SIZE would otherwise be computed in 32 bits and wrap
  // for very large seed counts.
  const int64_t tile_start = static_cast<int64_t>(blockIdx.x) * TILE_SIZE;
  int64_t idx = tile_start + threadIdx.x;
  const int64_t last_idx = min(tile_start + TILE_SIZE, num_seeds);
  int64_t trace_length = (max_num_steps + 1);
  curandState rng;
  // reference:
  //     https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes
  curand_init(rand_seed + idx, 0, 0, &rng);

  while (idx < last_idx) {
    IdType curr = seed_data[idx];
    assert(curr < max_nodes);
    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
    *(traces_data_ptr++) = curr;
    int64_t step_idx;
    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
      IdType metapath_id = metapath_data[step_idx];
      const GraphKernelData<IdType> &graph = graphs[metapath_id];
      const int64_t in_row_start = graph.in_ptr[curr];
      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
      if (deg == 0) {  // the degree is zero
        break;
      }
      const int64_t num = curand(&rng) % deg;
      IdType pick = graph.in_cols[in_row_start + num];
      IdType eid =
          (graph.data ? graph.data[in_row_start + num] : in_row_start + num);
      *traces_data_ptr = pick;
      *eids_data_ptr = eid;
      // On restart the pointers are not advanced, so the -1 fill below
      // overwrites the step just written.
      if ((restart_prob_size > 1) &&
          (curand_uniform(&rng) < restart_prob_data[step_idx])) {
        break;
      } else if (
          (restart_prob_size == 1) &&
          (curand_uniform(&rng) < restart_prob_data[0])) {
        break;
      }
      ++traces_data_ptr;
      ++eids_data_ptr;
      curr = pick;
    }
    for (; step_idx < max_num_steps; ++step_idx) {
      *(traces_data_ptr++) = -1;
      *(eids_data_ptr++) = -1;
    }
    idx += BLOCK_SIZE;
  }
}

/**
 * @brief Weighted random walk kernel using inverse transform sampling.
 *
 * Launch layout and output format are the same as the uniform kernel:
 * - 1-D blocks of exactly BLOCK_SIZE threads (asserted).
 * - Each block covers TILE_SIZE seeds, walked with a BLOCK_SIZE stride.
 * - Traces are (max_num_steps + 1) per seed and eids are max_num_steps per
 *   seed; positions from the stopping step on are -1.
 *
 * For step s with edge type t = metapath_data[s]:
 * - If probs[t] is nullptr, the neighbor index is curand() % deg.
 * - Otherwise a target prob_sums[t][curr] * curand_uniform() is drawn, and the
 *   first edge whose running weight sum reaches it is taken.
 *
 * Restart handling matches the uniform kernel: size 0 disables restarts, size 1
 * uses restart_prob_data[0] at every step, and otherwise restart_prob_data[s]
 * is used for step s.
 */
template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkBiasedKernel(
    const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
    const IdType *metapath_data, const uint64_t max_num_steps,
    const GraphKernelData<IdType> *graphs, const FloatType **probs,
    const FloatType **prob_sums, const FloatType *restart_prob_data,
    const int64_t restart_prob_size, const int64_t max_nodes,
    IdType *out_traces_data, IdType *out_eids_data) {
  assert(BLOCK_SIZE == blockDim.x);
  // Widen before multiplying so the tile offset cannot wrap in 32 bits.
  const int64_t tile_start = static_cast<int64_t>(blockIdx.x) * TILE_SIZE;
  int64_t idx = tile_start + threadIdx.x;
  const int64_t last_idx = min(tile_start + TILE_SIZE, num_seeds);
  int64_t trace_length = (max_num_steps + 1);
  curandState rng;
  // reference:
  //     https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes
  curand_init(rand_seed + idx, 0, 0, &rng);

  while (idx < last_idx) {
    IdType curr = seed_data[idx];
    assert(curr < max_nodes);
    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
    *(traces_data_ptr++) = curr;
    int64_t step_idx;
    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
      IdType metapath_id = metapath_data[step_idx];
      const GraphKernelData<IdType> &graph = graphs[metapath_id];
      const int64_t in_row_start = graph.in_ptr[curr];
      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
      if (deg == 0) {  // the degree is zero
        break;
      }

      // randomly select by weight
      const FloatType *prob_sum = prob_sums[metapath_id];
      const FloatType *prob = probs[metapath_id];
      int64_t num;
      if (prob == nullptr) {
        num = curand(&rng) % deg;
      } else {
        auto rnd_sum_w = prob_sum[curr] * curand_uniform(&rng);
        FloatType sum_w{0.};
        // prob_sum comes from a separate reduction that may round
        // differently, and curand_uniform can return exactly 1.0, so the
        // running sum may never reach the target. In that case take the last
        // positively weighted edge instead of indexing past the row.
        int64_t last_positive = deg - 1;
        for (num = 0; num < deg; ++num) {
          const FloatType w = prob[in_row_start + num];
          if (w > 0) last_positive = num;
          sum_w += w;
          if (sum_w >= rnd_sum_w) break;
        }
        if (num == deg) num = last_positive;
      }

      IdType pick = graph.in_cols[in_row_start + num];
      IdType eid =
          (graph.data ? graph.data[in_row_start + num] : in_row_start + num);
      *traces_data_ptr = pick;
      *eids_data_ptr = eid;
      // On restart the pointers are not advanced, so the -1 fill below
      // overwrites the step just written.
      if ((restart_prob_size > 1) &&
          (curand_uniform(&rng) < restart_prob_data[step_idx])) {
        break;
      } else if (
          (restart_prob_size == 1) &&
          (curand_uniform(&rng) < restart_prob_data[0])) {
        break;
      }
      ++traces_data_ptr;
      ++eids_data_ptr;
      curr = pick;
    }
    for (; step_idx < max_num_steps; ++step_idx) {
      *(traces_data_ptr++) = -1;
      *(eids_data_ptr++) = -1;
    }
    idx += BLOCK_SIZE;
  }
}
}  // namespace

/**
 * @brief Random walk for uniform choice: every step picks a neighbor of the
 * current node uniformly at random.
 *
 * @param hg Heterograph to walk on; the CSR matrix of every edge type is handed
 *        to the kernel.
 * @param seeds 1-D array of start nodes; output and temporary buffers are
 *        allocated on its context.
 * @param metapath Edge type ID of each step. Its data is dereferenced on the
 *        host (metapath_data[0]), and a copy is made on the seeds' context.
 * @param restart_prob 1-D array on a CUDA device with 0 entries (no restart),
 *        1 entry (same probability every step), or one entry per step. Its
 *        element type is dispatched with ATEN_FLOAT_TYPE_SWITCH.
 * @return (traces, eids) of shapes (num_seeds, len(metapath) + 1) and
 *         (num_seeds, len(metapath)); positions after a walk stops are -1.
 */
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkUniform(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    FloatArray restart_prob) {
  const int64_t max_num_steps = metapath->shape[0];
  const IdType *metapath_data = static_cast<IdType *>(metapath->data);
  const int64_t begin_ntype =
      hg->meta_graph()->FindEdge(metapath_data[0]).first;
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
  int64_t num_etypes = hg->NumEdgeTypes();
  auto ctx = seeds->ctx;

  // Validate up front: a CHECK failing after the workspace allocation below
  // would never reach the FreeWorkspace call at the end.
  CHECK(seeds->ndim == 1) << "seeds shape is not one dimension.";
  CHECK(restart_prob->ctx.device_type == kDGLCUDA)
      << "restart prob should be in GPU.";
  CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";

  const IdType *seed_data = static_cast<const IdType *>(seeds->data);
  const int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);
  IdType *traces_data = traces.Ptr<IdType>();
  IdType *eids_data = eids.Ptr<IdType>();

  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
    h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);
    h_graphs[etype].in_cols = static_cast<const IdType *>(csr.indices->data);
    h_graphs[etype].data =
        (CSRHasData(csr) ? static_cast<const IdType *>(csr.data->data)
                         : nullptr);
  }
  // use cuda stream from local thread
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto device = DeviceAPI::Get(ctx);
  auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(
      ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
  // copy graph metadata pointers to GPU
  device->CopyDataFromTo(
      h_graphs.data(), 0, d_graphs, 0,
      (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},
      ctx, hg->GetCSRMatrix(0).indptr->dtype);
  // copy metapath to GPU
  auto d_metapath = metapath.CopyTo(ctx);
  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);

  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 4;
  // The kernel asserts blockDim.x == BLOCK_SIZE, so launch with that constant.
  dim3 block(BLOCK_SIZE);
  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);

  ATEN_FLOAT_TYPE_SWITCH(
      restart_prob->dtype, FloatType, "random walk GPU kernel", {
        const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
        const int64_t restart_prob_size = restart_prob->shape[0];
        CUDA_KERNEL_CALL(
            (_RandomWalkKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,
            block, 0, stream, random_seed, seed_data, num_seeds,
            d_metapath_data, max_num_steps, d_graphs, restart_prob_data,
            restart_prob_size, max_nodes, traces_data, eids_data);
      });

  device->FreeWorkspace(ctx, d_graphs);
  return std::make_pair(traces, eids);
}

/**
 * @brief Return @p restart_prob with FloatType elements, on its own device.
 *
 * RandomWalkBiased hands restart probabilities to its kernel as a
 * `const FloatType *`. Callers may pass another float width, e.g. a float64
 * restart probability next to float32 edge weights. Reinterpreting those bytes
 * would give meaningless probabilities, so a non-empty array of another type is
 * converted through the host. It holds one value, or one per step. Empty arrays
 * and arrays that already hold FloatType are returned unchanged.
 */
template <typename FloatType>
static FloatArray RestartProbAsType(
    FloatArray restart_prob, DeviceAPI *device, cudaStream_t stream) {
  const int64_t size = restart_prob->shape[0];
  const DGLDataType target_dtype{
      kDGLFloat, static_cast<uint8_t>(sizeof(FloatType) * 8), 1};
  if (size == 0 || (restart_prob->dtype.code == target_dtype.code &&
                    restart_prob->dtype.bits == target_dtype.bits)) {
    return restart_prob;
  }
  const DGLContext src_ctx = restart_prob->ctx;
  std::vector<FloatType> converted(size);
  ATEN_FLOAT_TYPE_SWITCH(
      restart_prob->dtype, SrcFloatType, "restart probability", {
        std::vector<SrcFloatType> host_copy(size);
        device->CopyDataFromTo(
            restart_prob.Ptr<SrcFloatType>(), 0, host_copy.data(), 0,
            size * sizeof(SrcFloatType), src_ctx, DGLContext{kDGLCPU, 0},
            restart_prob->dtype);
        device->StreamSync(src_ctx, stream);
        for (int64_t i = 0; i < size; ++i) {
          converted[i] = static_cast<FloatType>(host_copy[i]);
        }
      });
  FloatArray result = NDArray::Empty({size}, target_dtype, src_ctx);
  device->CopyDataFromTo(
      converted.data(), 0, result.Ptr<FloatType>(), 0,
      size * sizeof(FloatType), DGLContext{kDGLCPU, 0}, src_ctx,
      result->dtype);
  // `converted` is released on return, so the upload has to finish first.
  device->StreamSync(src_ctx, stream);
  return result;
}

/**
 * @brief Random walk for biased choice. We use inverse transform sampling to
 * choose the next step.
 *
 * Per-node weight sums are precomputed for every weighted edge type with
 * cub::DeviceSegmentedReduce::Sum over the CSR rows.
 *
 * @param hg Heterograph to walk on; the CSR matrix of every edge type is used.
 * @param seeds 1-D array of start nodes; output and temporary buffers are
 *        allocated on its context.
 * @param metapath Edge type ID of each step. Its data is dereferenced on the
 *        host (metapath_data[0]), and a copy is made on the seeds' context.
 * @param prob One weight array per edge type, all non-empty ones of type
 *        FloatType. An empty (null) array means uniform choice for that edge
 *        type.
 * @param restart_prob 1-D array on a CUDA device with 0 entries (no restart),
 *        1 entry (same probability every step), or one entry per step. Other
 *        float widths are converted to FloatType.
 * @return (traces, eids) of shapes (num_seeds, len(metapath) + 1) and
 *         (num_seeds, len(metapath)); positions after a walk stops are -1.
 */
template <DGLDeviceType XPU, typename FloatType, typename IdType>
std::pair<IdArray, IdArray> RandomWalkBiased(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob) {
  const int64_t max_num_steps = metapath->shape[0];
  const IdType *metapath_data = static_cast<IdType *>(metapath->data);
  const int64_t begin_ntype =
      hg->meta_graph()->FindEdge(metapath_data[0]).first;
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
  int64_t num_etypes = hg->NumEdgeTypes();
  auto ctx = seeds->ctx;

  // All argument checks run before any workspace is allocated: a CHECK
  // failing later would never reach the FreeWorkspace calls at the end.
  CHECK(seeds->ndim == 1) << "seeds shape is not one dimension.";
  CHECK(restart_prob->ctx.device_type == kDGLCUDA)
      << "restart prob should be in GPU.";
  CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
  // prob[etype] is read for every edge type below (previously a debug-only
  // assert).
  CHECK_EQ(num_etypes, static_cast<int64_t>(prob.size()))
      << "prob should have one entry per edge type.";
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    if (IsNullArray(prob[etype])) continue;
    // Every weight array is read through a FloatType pointer.
    CHECK(
        prob[etype]->dtype.code == kDGLFloat &&
        prob[etype]->dtype.bits == static_cast<int>(sizeof(FloatType) * 8))
        << "all non-empty edge probability arrays should have the same "
           "floating-point type.";
  }

  const IdType *seed_data = static_cast<const IdType *>(seeds->data);
  const int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);
  IdType *traces_data = traces.Ptr<IdType>();
  IdType *eids_data = eids.Ptr<IdType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto device = DeviceAPI::Get(ctx);

  // The kernel reads restart probabilities through a FloatType pointer.
  restart_prob = RestartProbAsType<FloatType>(restart_prob, device, stream);

  // new probs and prob sums pointers
  std::unique_ptr<FloatType *[]> probs(new FloatType *[prob.size()]);
  std::unique_ptr<FloatType *[]> prob_sums(new FloatType *[prob.size()]);
  std::vector<FloatArray> prob_sums_arr;
  prob_sums_arr.reserve(prob.size());

  // graphs
  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
    h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);
    h_graphs[etype].in_cols = static_cast<const IdType *>(csr.indices->data);
    h_graphs[etype].data =
        (CSRHasData(csr) ? static_cast<const IdType *>(csr.data->data)
                         : nullptr);

    int64_t num_segments = csr.indptr->shape[0] - 1;
    // will handle empty probs in the kernel
    if (IsNullArray(prob[etype])) {
      probs[etype] = nullptr;
      prob_sums[etype] = nullptr;
      continue;
    }
    probs[etype] = prob[etype].Ptr<FloatType>();
    prob_sums_arr.push_back(
        FloatArray::Empty({num_segments}, prob[etype]->dtype, ctx));
    // prob_sums_arr only has entries for weighted edge types, so it must not
    // be indexed by etype; the array just added is at the back.
    prob_sums[etype] = prob_sums_arr.back().Ptr<FloatType>();

    // calculate the sum of the neighbor weights
    const IdType *d_offsets = static_cast<const IdType *>(csr.indptr->data);
    size_t temp_storage_size = 0;
    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(
        nullptr, temp_storage_size, probs[etype], prob_sums[etype],
        num_segments, d_offsets, d_offsets + 1, stream));
    void *temp_storage = device->AllocWorkspace(ctx, temp_storage_size);
    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(
        temp_storage, temp_storage_size, probs[etype], prob_sums[etype],
        num_segments, d_offsets, d_offsets + 1, stream));
    device->FreeWorkspace(ctx, temp_storage);
  }

  // copy graph metadata pointers to GPU
  auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(
      ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
  device->CopyDataFromTo(
      h_graphs.data(), 0, d_graphs, 0,
      (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},
      ctx, hg->GetCSRMatrix(0).indptr->dtype);
  // copy probs pointers to GPU
  const FloatType **probs_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
  device->CopyDataFromTo(
      probs.get(), 0, probs_dev, 0, (num_etypes) * sizeof(FloatType *),
      DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);
  // copy probs_sum pointers to GPU
  const FloatType **prob_sums_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
  device->CopyDataFromTo(
      prob_sums.get(), 0, prob_sums_dev, 0, (num_etypes) * sizeof(FloatType *),
      DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);
  // copy metapath to GPU
  auto d_metapath = metapath.CopyTo(ctx);
  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);

  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 4;
  // The kernel asserts blockDim.x == BLOCK_SIZE, so launch with that constant.
  dim3 block(BLOCK_SIZE);
  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);

  const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
  const int64_t restart_prob_size = restart_prob->shape[0];
  CUDA_KERNEL_CALL(
      (_RandomWalkBiasedKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,
      block, 0, stream, random_seed, seed_data, num_seeds, d_metapath_data,
      max_num_steps, d_graphs, probs_dev, prob_sums_dev, restart_prob_data,
      restart_prob_size, max_nodes, traces_data, eids_data);

  device->FreeWorkspace(ctx, d_graphs);
  device->FreeWorkspace(ctx, probs_dev);
  device->FreeWorkspace(ctx, prob_sums_dev);
  return std::make_pair(traces, eids);
}

/**
 * @brief Random walk without restart, dispatching to the uniform or the biased
 * implementation.
 *
 * The biased path is taken as soon as any edge type has a non-empty weight
 * array. Its float type comes from the first such array, not from prob[0],
 * which may be an empty placeholder with an unrelated dtype.
 *
 * @return (traces, eids) as produced by RandomWalkUniform / RandomWalkBiased.
 */
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalk(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
  bool isUniform = true;
  DGLDataType prob_dtype{kDGLFloat, 32, 1};
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      prob_dtype = etype_prob->dtype;
      break;
    }
  }

  // An empty restart array disables restarts in the kernels. It is placed on
  // the seeds' device, matching RandomWalkWithRestart.
  auto restart_prob =
      NDArray::Empty({0}, DGLDataType{kDGLFloat, 32, 1}, seeds->ctx);

  if (isUniform) {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
  std::pair<IdArray, IdArray> ret;
  ATEN_FLOAT_TYPE_SWITCH(prob_dtype, FloatType, "probability", {
    ret = RandomWalkBiased<XPU, FloatType, IdType>(
        hg, seeds, metapath, prob, restart_prob);
  });
  return ret;
}

/**
 * @brief Random walk that may restart after every step with a fixed
 * probability.
 *
 * The scalar @p restart_prob is uploaded as a one-element array on the seeds'
 * device. The biased path is taken as soon as any edge type has a non-empty
 * weight array, and its float type comes from the first such array.
 *
 * @param restart_prob Probability of stopping after each step.
 * @return (traces, eids) as produced by RandomWalkUniform / RandomWalkBiased.
 */
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithRestart(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob) {
  bool isUniform = true;
  DGLDataType prob_dtype{kDGLFloat, 64, 1};
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      prob_dtype = etype_prob->dtype;
      break;
    }
  }

  auto device_ctx = seeds->ctx;
  // The uniform path dispatches on this array's own dtype, but RandomWalkBiased
  // reads it through a pointer to the edge-probability type. Store the scalar
  // in that type so float32 weights never reinterpret the bytes of a double.
  const DGLDataType restart_dtype =
      isUniform ? DGLDataType{kDGLFloat, 64, 1} : prob_dtype;
  auto restart_prob_array = NDArray::Empty({1}, restart_dtype, device_ctx);
  auto device = dgl::runtime::DeviceAPI::Get(device_ctx);

  // use cuda stream from local thread
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  ATEN_FLOAT_TYPE_SWITCH(restart_dtype, RestartType, "probability", {
    const RestartType restart_value = static_cast<RestartType>(restart_prob);
    device->CopyDataFromTo(
        &restart_value, 0, restart_prob_array.Ptr<RestartType>(), 0,
        sizeof(RestartType), DGLContext{kDGLCPU, 0}, device_ctx,
        restart_prob_array->dtype);
    // restart_value lives on the host stack; finish the copy before it goes
    // out of scope.
    device->StreamSync(device_ctx, stream);
  });

  if (isUniform) {
    return RandomWalkUniform<XPU, IdType>(
        hg, seeds, metapath, restart_prob_array);
  }
  std::pair<IdArray, IdArray> ret;
  ATEN_FLOAT_TYPE_SWITCH(prob_dtype, FloatType, "probability", {
    ret = RandomWalkBiased<XPU, FloatType, IdType>(
        hg, seeds, metapath, prob, restart_prob_array);
  });
  return ret;
}

/**
 * @brief Random walk with a caller-supplied restart probability array,
 * forwarded unchanged to the uniform or biased implementation.
 *
 * The biased path is taken as soon as any edge type has a non-empty weight
 * array. Its float type comes from the first such array, not from prob[0],
 * which may be an empty placeholder with an unrelated dtype.
 *
 * @param restart_prob 1-D CUDA array; presumably one probability per step
 *        (the kernels index it by step when it has more than one entry).
 * @return (traces, eids) as produced by RandomWalkUniform / RandomWalkBiased.
 */
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob) {
  bool isUniform = true;
  DGLDataType prob_dtype{kDGLFloat, 32, 1};
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      prob_dtype = etype_prob->dtype;
      break;
    }
  }

  if (isUniform) {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
  std::pair<IdArray, IdArray> ret;
  ATEN_FLOAT_TYPE_SWITCH(prob_dtype, FloatType, "probability", {
    ret = RandomWalkBiased<XPU, FloatType, IdType>(
        hg, seeds, metapath, prob, restart_prob);
  });
  return ret;
}

/**
 * @brief PinSAGE neighbor selection on the GPU.
 *
 * Builds a FrequencyHashmap sized for dst->shape[0] / num_samples_per_node
 * destination nodes and returns its Topk result for the (src, dst) pairs; see
 * frequency_hashmap.cuh for the meaning of the returned triple.
 *
 * @param src Source node of each sampled pair (CUDA array).
 * @param dst Destination node of each sampled pair (CUDA array); its length is
 *        expected to be a multiple of num_samples_per_node.
 * @param num_samples_per_node Number of samples per destination node; must be
 *        positive because it is used as a divisor.
 * @param k Number of entries kept per destination node by Topk.
 */
template <DGLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
    const int64_t k) {
  CHECK(src->ctx.device_type == kDGLCUDA) << "IdArray needs be on GPU!";
  // dst is also handed to device code, so it must live on the GPU as well.
  CHECK(dst->ctx.device_type == kDGLCUDA) << "IdArray needs be on GPU!";
  // Guards the integer division below against a zero divisor.
  CHECK_GT(num_samples_per_node, 0) << "num_samples_per_node must be positive.";
  const IdxType *src_data = src.Ptr<IdxType>();
  const IdxType *dst_data = dst.Ptr<IdxType>();
  const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);
  auto ctx = src->ctx;
  // use cuda stream from local thread
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto frequency_hashmap = FrequencyHashmap<IdxType>(
      num_dst_nodes, num_samples_per_node, ctx, stream);
  auto ret = frequency_hashmap.Topk(
      src_data, dst_data, src->dtype, src->shape[0], num_samples_per_node, k);
  return ret;
}

// Explicit instantiations of the CUDA entry points for 32-bit node IDs.
template std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob);
template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob);
template std::pair<IdArray, IdArray>
RandomWalkWithStepwiseRestart<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob);
template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDGLCUDA, int32_t>(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
    const int64_t k);

// The same entry points for 64-bit node IDs.
template std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob);
template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob);
template std::pair<IdArray, IdArray>
RandomWalkWithStepwiseRestart<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob);
template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDGLCUDA, int64_t>(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
    const int64_t k);

};  // namespace impl

};  // namespace sampling

};  // namespace dgl