// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2021-2022 by Contributors
 * @file graph/sampling/randomwalk_gpu.cu
 * @brief CUDA random walk sampling
 */

#include <hiprand/hiprand_kernel.h>
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>

#include <hipcub/hipcub.hpp>
#include <tuple>
#include <utility>
#include <vector>

#include "../../../runtime/cuda/cuda_common.h"
#include "frequency_hashmap.cuh"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

namespace {

template <typename IdType>
struct GraphKernelData {
  const IdType *in_ptr;
  const IdType *in_cols;
  const IdType *data;
};

template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkKernel(
    const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
    const IdType *metapath_data, const uint64_t max_num_steps,
    const GraphKernelData<IdType> *graphs, const FloatType *restart_prob_data,
    const int64_t restart_prob_size, const int64_t max_nodes,
    IdType *out_traces_data, IdType *out_eids_data) {
  assert(BLOCK_SIZE == blockDim.x);
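  // each block handles a tile of TILE_SIZE consecutive seeds; its threads
  // stride through the tile with step BLOCK_SIZE (see idx += BLOCK_SIZE below)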
  int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;
  int64_t last_idx =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);
  int64_t trace_length = (max_num_steps + 1);
  hiprandState_t rng;
  // reference:
  //     https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes
  hiprand_init(rand_seed + idx, 0, 0, &rng);

  while (idx < last_idx) {
    IdType curr = seed_data[idx];
    assert(curr < max_nodes);
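    // each trace row starts with the seed node; eids holds one entry per step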
    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
    *(traces_data_ptr++) = curr;
    int64_t step_idx;
    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
      IdType metapath_id = metapath_data[step_idx];
      const GraphKernelData<IdType> &graph = graphs[metapath_id];
      const int64_t in_row_start = graph.in_ptr[curr];
      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
      if (deg == 0) {  // no neighbors to walk to; stop this trace
        break;
      }
      const int64_t num = hiprand(&rng) % deg;
      IdType pick = graph.in_cols[in_row_start + num];
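      // graph.data maps CSR positions to edge IDs; when it is absent the CSR
      // position itself is the edge ID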
      IdType eid =
          (graph.data ? graph.data[in_row_start + num] : in_row_start + num);
      *traces_data_ptr = pick;
      *eids_data_ptr = eid;
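      // restart check: restart_prob_size > 1 means one probability per step,
      // size 1 means a single shared probability, size 0 disables restarts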
      if ((restart_prob_size > 1) &&
          (hiprand_uniform(&rng) < restart_prob_data[step_idx])) {
        break;
      } else if (
          (restart_prob_size == 1) &&
          (hiprand_uniform(&rng) < restart_prob_data[0])) {
        break;
      }
      ++traces_data_ptr;
      ++eids_data_ptr;
      curr = pick;
    }
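    // pad early-terminated walks with -1 so every row has a fixed width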
    for (; step_idx < max_num_steps; ++step_idx) {
      *(traces_data_ptr++) = -1;
      *(eids_data_ptr++) = -1;
    }
    idx += BLOCK_SIZE;
  }
}

template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkBiasedKernel(
    const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
    const IdType *metapath_data, const uint64_t max_num_steps,
    const GraphKernelData<IdType> *graphs, const FloatType **probs,
    const FloatType **prob_sums, const FloatType *restart_prob_data,
    const int64_t restart_prob_size, const int64_t max_nodes,
    IdType *out_traces_data, IdType *out_eids_data) {
  assert(BLOCK_SIZE == blockDim.x);
  int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;
  int64_t last_idx =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);
  int64_t trace_length = (max_num_steps + 1);
  hiprandState_t rng;
  // reference:
  //     https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes
  hiprand_init(rand_seed + idx, 0, 0, &rng);

  while (idx < last_idx) {
    IdType curr = seed_data[idx];
    assert(curr < max_nodes);
    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
    *(traces_data_ptr++) = curr;
    int64_t step_idx;
    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
      IdType metapath_id = metapath_data[step_idx];
      const GraphKernelData<IdType> &graph = graphs[metapath_id];
      const int64_t in_row_start = graph.in_ptr[curr];
      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
      if (deg == 0) {  // no neighbors to walk to; stop this trace
        break;
      }

      // randomly select by weight
      const FloatType *prob_sum = prob_sums[metapath_id];
      const FloatType *prob = probs[metapath_id];
      int64_t num;
      if (prob == nullptr) {
        num = hiprand(&rng) % deg;
      } else {
        auto rnd_sum_w = prob_sum[curr] * hiprand_uniform(&rng);
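        // inverse transform sampling: draw a uniform threshold scaled by the
        // node's total out-weight, then scan until the running sum reaches it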
        FloatType sum_w{0.};
        for (num = 0; num < deg; ++num) {
          sum_w += prob[in_row_start + num];
          if (sum_w >= rnd_sum_w) break;
        }
      }

      IdType pick = graph.in_cols[in_row_start + num];
      IdType eid =
          (graph.data ? graph.data[in_row_start + num] : in_row_start + num);
      *traces_data_ptr = pick;
      *eids_data_ptr = eid;
      if ((restart_prob_size > 1) &&
          (hiprand_uniform(&rng) < restart_prob_data[step_idx])) {
        break;
      } else if (
          (restart_prob_size == 1) &&
          (hiprand_uniform(&rng) < restart_prob_data[0])) {
        break;
      }
      ++traces_data_ptr;
      ++eids_data_ptr;
      curr = pick;
    }
    for (; step_idx < max_num_steps; ++step_idx) {
      *(traces_data_ptr++) = -1;
      *(eids_data_ptr++) = -1;
    }
    idx += BLOCK_SIZE;
  }
}

}  // namespace

// random walk for uniform choice
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkUniform(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    FloatArray restart_prob) {
  const int64_t max_num_steps = metapath->shape[0];
  const IdType *metapath_data = static_cast<IdType *>(metapath->data);
  const int64_t begin_ntype =
      hg->meta_graph()->FindEdge(metapath_data[0]).first;
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
  int64_t num_etypes = hg->NumEdgeTypes();
  auto ctx = seeds->ctx;

  const IdType *seed_data = static_cast<const IdType *>(seeds->data);
  CHECK(seeds->ndim == 1) << "seeds should be a 1D array.";
  const int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
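  // each output row holds the seed plus one node per step; eids holds one
  // edge ID per step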
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);
  IdType *traces_data = traces.Ptr<IdType>();
  IdType *eids_data = eids.Ptr<IdType>();

  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
    h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);
    h_graphs[etype].in_cols = static_cast<const IdType *>(csr.indices->data);
    h_graphs[etype].data =
        (CSRHasData(csr) ? static_cast<const IdType *>(csr.data->data)
                         : nullptr);
  }
  // use cuda stream from local thread
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  auto device = DeviceAPI::Get(ctx);
  auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(
      ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
  // copy graph metadata pointers to GPU
  device->CopyDataFromTo(
      h_graphs.data(), 0, d_graphs, 0,
      (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},
      ctx, hg->GetCSRMatrix(0).indptr->dtype);
  // copy metapath to GPU
  auto d_metapath = metapath.CopyTo(ctx);
  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);

  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 4;
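  // launch one thread block per tile of TILE_SIZE seeds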
  dim3 block(BLOCK_SIZE);
  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
  ATEN_FLOAT_TYPE_SWITCH(
      restart_prob->dtype, FloatType, "random walk GPU kernel", {
        CHECK(
            restart_prob->ctx.device_type == kDGLCUDA ||
            restart_prob->ctx.device_type == kDGLROCM)
            << "restart prob should be on GPU.";
        CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
        const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
        const int64_t restart_prob_size = restart_prob->shape[0];
        CUDA_KERNEL_CALL(
            (_RandomWalkKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,
            block, 0, stream, random_seed, seed_data, num_seeds,
            d_metapath_data, max_num_steps, d_graphs, restart_prob_data,
            restart_prob_size, max_nodes, traces_data, eids_data);
      });

  device->FreeWorkspace(ctx, d_graphs);
  return std::make_pair(traces, eids);
}

/**
 * @brief Random walk for biased choice. We use inverse transform sampling to
 * choose the next step.
 */
template <DGLDeviceType XPU, typename FloatType, typename IdType>
std::pair<IdArray, IdArray> RandomWalkBiased(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob) {
  const int64_t max_num_steps = metapath->shape[0];
  const IdType *metapath_data = static_cast<IdType *>(metapath->data);
  const int64_t begin_ntype =
      hg->meta_graph()->FindEdge(metapath_data[0]).first;
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
  int64_t num_etypes = hg->NumEdgeTypes();
  auto ctx = seeds->ctx;

  const IdType *seed_data = static_cast<const IdType *>(seeds->data);
  CHECK(seeds->ndim == 1) << "seeds should be a 1D array.";
  const int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);
  IdType *traces_data = traces.Ptr<IdType>();
  IdType *eids_data = eids.Ptr<IdType>();

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  auto device = DeviceAPI::Get(ctx);
  // new probs and prob sums pointers
  assert(num_etypes == static_cast<int64_t>(prob.size()));
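  // host-side arrays of device pointers: one (possibly null) edge-weight
  // array and one per-node weight-sum array per edge type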
  std::unique_ptr<FloatType *[]> probs(new FloatType *[prob.size()]);
  std::unique_ptr<FloatType *[]> prob_sums(new FloatType *[prob.size()]);
  std::vector<FloatArray> prob_sums_arr;
  prob_sums_arr.reserve(prob.size());

  // graphs
  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
    h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);
    h_graphs[etype].in_cols = static_cast<const IdType *>(csr.indices->data);
    h_graphs[etype].data =
        (CSRHasData(csr) ? static_cast<const IdType *>(csr.data->data)
                         : nullptr);

    int64_t num_segments = csr.indptr->shape[0] - 1;
    // will handle empty probs in the kernel
    if (IsNullArray(prob[etype])) {
      probs[etype] = nullptr;
      prob_sums[etype] = nullptr;
      continue;
    }
    probs[etype] = prob[etype].Ptr<FloatType>();
    prob_sums_arr.push_back(
        FloatArray::Empty({num_segments}, prob[etype]->dtype, ctx));
    // prob_sums_arr only gets an entry for edge types with a non-empty prob
    // array, so take the entry that was just pushed rather than indexing by
    // etype
    prob_sums[etype] = prob_sums_arr.back().Ptr<FloatType>();

    // calculate the sum of the neighbor weights
    const IdType *d_offsets = static_cast<const IdType *>(csr.indptr->data);
    size_t temp_storage_size = 0;
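    // CUB-style two-phase call: the first call with a null workspace only
    // queries the required temporary storage size, the second call reduces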
    CUDA_CALL(hipcub::DeviceSegmentedReduce::Sum(
        nullptr, temp_storage_size, probs[etype], prob_sums[etype],
        num_segments, d_offsets, d_offsets + 1, stream));
    void *temp_storage = device->AllocWorkspace(ctx, temp_storage_size);
    CUDA_CALL(hipcub::DeviceSegmentedReduce::Sum(
        temp_storage, temp_storage_size, probs[etype], prob_sums[etype],
        num_segments, d_offsets, d_offsets + 1, stream));
    device->FreeWorkspace(ctx, temp_storage);
  }

  // copy graph metadata pointers to GPU
  auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(
      ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
  device->CopyDataFromTo(
      h_graphs.data(), 0, d_graphs, 0,
      (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},
      ctx, hg->GetCSRMatrix(0).indptr->dtype);
  // copy probs pointers to GPU
  const FloatType **probs_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
  device->CopyDataFromTo(
      probs.get(), 0, probs_dev, 0, (num_etypes) * sizeof(FloatType *),
      DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);
  // copy probs_sum pointers to GPU
  const FloatType **prob_sums_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
  device->CopyDataFromTo(
      prob_sums.get(), 0, prob_sums_dev, 0, (num_etypes) * sizeof(FloatType *),
      DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);
  // copy metapath to GPU
  auto d_metapath = metapath.CopyTo(ctx);
  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);

  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 4;
  dim3 block(BLOCK_SIZE);
  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
  CHECK(
      restart_prob->ctx.device_type == kDGLCUDA ||
      restart_prob->ctx.device_type == kDGLROCM)
      << "restart prob should be on GPU.";
  CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
  const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
  const int64_t restart_prob_size = restart_prob->shape[0];
  CUDA_KERNEL_CALL(
      (_RandomWalkBiasedKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,
      block, 0, stream, random_seed, seed_data, num_seeds, d_metapath_data,
      max_num_steps, d_graphs, probs_dev, prob_sums_dev, restart_prob_data,
      restart_prob_size, max_nodes, traces_data, eids_data);

  device->FreeWorkspace(ctx, d_graphs);
  device->FreeWorkspace(ctx, probs_dev);
  device->FreeWorkspace(ctx, prob_sums_dev);
  return std::make_pair(traces, eids);
}

template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalk(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
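  // the walk is uniform only if no edge type carries a probability array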
  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }

  auto restart_prob =
      NDArray::Empty({0}, DGLDataType{kDGLFloat, 32, 1}, DGLContext{XPU, 0});
  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      ret = RandomWalkBiased<XPU, FloatType, IdType>(
          hg, seeds, metapath, prob, restart_prob);
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
}

template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithRestart(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob) {
  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }

  auto device_ctx = seeds->ctx;
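  // wrap the scalar restart probability in a length-1 device array so the
  // kernels can treat fixed and stepwise restart probabilities uniformly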
  auto restart_prob_array =
      NDArray::Empty({1}, DGLDataType{kDGLFloat, 64, 1}, device_ctx);
  auto device = dgl::runtime::DeviceAPI::Get(device_ctx);

  // use cuda stream from local thread
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  device->CopyDataFromTo(
      &restart_prob, 0, restart_prob_array.Ptr<double>(), 0, sizeof(double),
      DGLContext{kDGLCPU, 0}, device_ctx, restart_prob_array->dtype);
  device->StreamSync(device_ctx, stream);

  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      ret = RandomWalkBiased<XPU, FloatType, IdType>(
          hg, seeds, metapath, prob, restart_prob_array);
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(
        hg, seeds, metapath, restart_prob_array);
  }
}

template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob) {
  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }

  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      ret = RandomWalkBiased<XPU, FloatType, IdType>(
          hg, seeds, metapath, prob, restart_prob);
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
}

template <DGLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
    const int64_t k) {
  CHECK(src->ctx.device_type == kDGLCUDA || src->ctx.device_type == kDGLROCM)
      << "IdArray needs to be on GPU!";
  const IdxType *src_data = src.Ptr<IdxType>();
  const IdxType *dst_data = dst.Ptr<IdxType>();
  const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);
  auto ctx = src->ctx;
  // use cuda stream from local thread
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
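  // count how often each src node is visited per dst node and keep the top-k
  // most frequent ones (PinSAGE-style neighbor selection)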
  auto frequency_hashmap = FrequencyHashmap<IdxType>(
      num_dst_nodes, num_samples_per_node, ctx, stream);
  auto ret = frequency_hashmap.Topk(
      src_data, dst_data, src->dtype, src->shape[0], num_samples_per_node, k);
  return ret;
}

template std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob);
template std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob);

template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob);
template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob);

template std::pair<IdArray, IdArray>
RandomWalkWithStepwiseRestart<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob);
template std::pair<IdArray, IdArray>
RandomWalkWithStepwiseRestart<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob);

template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDGLCUDA, int32_t>(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
    const int64_t k);
template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDGLCUDA, int64_t>(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
    const int64_t k);

};  // namespace impl

};  // namespace sampling

};  // namespace dgl