randomwalk_gpu.cu 19.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
#include "hip/hip_runtime.h"
/*!
 *  Copyright (c) 2021-2022 by Contributors
 * \file graph/sampling/randomwalk_gpu.cu
 * \brief CUDA random walk sampling
 */

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/device_api.h>
#include <dgl/random.h>
#include <hiprand_kernel.h>
#include <vector>
#include <utility>
#include <tuple>

#include "../../../array/cuda/dgl_cub.cuh"
#include "../../../runtime/cuda/cuda_common.h"
#include "frequency_hashmap.cuh"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

namespace {

/*!
 * \brief Kernel-side view of one edge type's CSR arrays.
 *
 * The host fills one instance per edge type from hg->GetCSRMatrix(etype) and
 * copies the array of instances byte-wise to the device; the kernels then pick
 * the entry for the current step with graphs[metapath_data[step]].
 */
template<typename IdType>
struct GraphKernelData {
  // CSR row pointer (indptr): node v's neighbors occupy [in_ptr[v], in_ptr[v + 1]).
  const IdType *in_ptr;
  // CSR column indices: the neighbor node ids.
  const IdType *in_cols;
  // CSR edge ids, or nullptr when the CSR has no data array; in that case the
  // kernels use the position inside in_cols as the edge id.
  const IdType *data;
};

/*!
 * \brief Uniform metapath random walk.
 *
 * Launch layout: 1-D grid of ceil(num_seeds / TILE_SIZE) blocks with exactly
 * BLOCK_SIZE threads each (asserted). Block b owns the seeds
 * [b * TILE_SIZE, min((b + 1) * TILE_SIZE, num_seeds)); thread t of that block
 * walks seeds tile_begin + t, tile_begin + t + BLOCK_SIZE, ... within the tile.
 *
 * For each seed, the trace row (length max_num_steps + 1) starts with the seed
 * and receives one node per step, and the eid row (length max_num_steps)
 * receives the edge taken. A walk stops early when the current node has no
 * neighbor in the step's edge type, or when the restart draw succeeds. In both
 * cases every remaining slot, including the one of the stopping step, is -1.
 *
 * \param rand_seed         Base RNG seed; each thread seeds with rand_seed + its first seed index.
 * \param seed_data         Device array of num_seeds start nodes.
 * \param metapath_data     Device array of max_num_steps edge-type ids.
 * \param graphs            Device array of per-edge-type CSR views, indexed by edge type.
 * \param restart_prob_data Device array of restart probabilities: empty (never restart),
 *                          one value used at every step, or one value per step.
 * \param restart_prob_size Number of entries in restart_prob_data.
 * \param max_nodes         Node count of the starting node type (only checked by assert).
 * \param out_traces_data   Output rows of length max_num_steps + 1, one per seed.
 * \param out_eids_data     Output rows of length max_num_steps, one per seed.
 */
template<typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkKernel(
    const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
    const IdType* metapath_data, const uint64_t max_num_steps,
    const GraphKernelData<IdType>* graphs,
    const FloatType* restart_prob_data,
    const int64_t restart_prob_size,
    const int64_t max_nodes,
    IdType *out_traces_data,
    IdType *out_eids_data) {
  assert(BLOCK_SIZE == blockDim.x);
  // Widen blockIdx.x before multiplying: otherwise the product is computed in
  // 32-bit unsigned arithmetic and wraps for very large seed counts.
  const int64_t tile_begin = static_cast<int64_t>(blockIdx.x) * TILE_SIZE;
  int64_t idx = tile_begin + threadIdx.x;
  int64_t last_idx = min(tile_begin + TILE_SIZE, num_seeds);
  int64_t trace_length = (max_num_steps + 1);
  hiprandState rng;
  // reference:
  //     https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes
  hiprand_init(rand_seed + idx, 0, 0, &rng);

  while (idx < last_idx) {
    IdType curr = seed_data[idx];
    assert(curr < max_nodes);
    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
    *(traces_data_ptr++) = curr;
    int64_t step_idx;
    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
      IdType metapath_id = metapath_data[step_idx];
      const GraphKernelData<IdType> &graph = graphs[metapath_id];
      const int64_t in_row_start = graph.in_ptr[curr];
      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
      if (deg == 0) {  // dead end: no neighbor in this edge type
        break;
      }
      const int64_t num = hiprand(&rng) % deg;
      IdType pick = graph.in_cols[in_row_start + num];
      IdType eid = (graph.data? graph.data[in_row_start + num] : in_row_start + num);
      *traces_data_ptr = pick;
      *eids_data_ptr = eid;
      // A restart ends the walk before this transition is kept: the pointers
      // are not advanced, so the -1 fill below overwrites the pick just written.
      if ((restart_prob_size > 1) && (hiprand_uniform(&rng) < restart_prob_data[step_idx])) {
        break;
      } else if ((restart_prob_size == 1) && (hiprand_uniform(&rng) < restart_prob_data[0])) {
        break;
      }
      ++traces_data_ptr; ++eids_data_ptr;
      curr = pick;
    }
    // Pad the rest of the trace after an early stop.
    for (; step_idx < max_num_steps; ++step_idx) {
      *(traces_data_ptr++) = -1;
      *(eids_data_ptr++) = -1;
    }
    idx += BLOCK_SIZE;
  }
}

/*!
 * \brief Weighted metapath random walk using inverse transform sampling.
 *
 * Launch layout and output format are the same as for the uniform kernel:
 * 1-D grid of ceil(num_seeds / TILE_SIZE) blocks of exactly BLOCK_SIZE threads
 * (asserted), each block owning TILE_SIZE consecutive seeds. Trace rows have
 * max_num_steps + 1 entries and eid rows max_num_steps entries. Slots after an
 * early stop (dead end or restart) are -1.
 *
 * \param probs     Per-edge-type device pointers to edge weights, laid out like the
 *                  CSR column array; a nullptr entry means uniform choice for that type.
 * \param prob_sums Per-edge-type device pointers to the per-node sum of outgoing
 *                  weights (only read when the matching probs entry is non-null).
 * The remaining parameters are as in the uniform kernel.
 */
template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkBiasedKernel(
    const uint64_t rand_seed,
    const IdType *seed_data,
    const int64_t num_seeds,
    const IdType *metapath_data,
    const uint64_t max_num_steps,
    const GraphKernelData<IdType> *graphs,
    const FloatType **probs,
    const FloatType **prob_sums,
    const FloatType *restart_prob_data,
    const int64_t restart_prob_size,
    const int64_t max_nodes,
    IdType *out_traces_data,
    IdType *out_eids_data) {
  assert(BLOCK_SIZE == blockDim.x);
  // Widen blockIdx.x before multiplying: otherwise the product is computed in
  // 32-bit unsigned arithmetic and wraps for very large seed counts.
  const int64_t tile_begin = static_cast<int64_t>(blockIdx.x) * TILE_SIZE;
  int64_t idx = tile_begin + threadIdx.x;
  int64_t last_idx = min(tile_begin + TILE_SIZE, num_seeds);
  int64_t trace_length = (max_num_steps + 1);
  hiprandState rng;
  // reference:
  //     https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes
  hiprand_init(rand_seed + idx, 0, 0, &rng);

  while (idx < last_idx) {
    IdType curr = seed_data[idx];
    assert(curr < max_nodes);
    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
    *(traces_data_ptr++) = curr;
    int64_t step_idx;
    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
      IdType metapath_id = metapath_data[step_idx];
      const GraphKernelData<IdType> &graph = graphs[metapath_id];
      const int64_t in_row_start = graph.in_ptr[curr];
      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
      if (deg == 0) {  // dead end: no neighbor in this edge type
        break;
      }

      // randomly select by weight
      const FloatType *prob_sum = prob_sums[metapath_id];
      const FloatType *prob = probs[metapath_id];
      int64_t num;
      if (prob == nullptr) {
        num = hiprand(&rng) % deg;
      } else {
        auto rnd_sum_w = prob_sum[curr] * hiprand_uniform(&rng);
        FloatType sum_w{0.};
        for (num = 0; num < deg; ++num) {
          sum_w += prob[in_row_start + num];
          if (sum_w >= rnd_sum_w) break;
        }
        // prob_sum[curr] is produced by a segmented reduction whose summation
        // order differs from this sequential prefix sum, so rounding (or a NaN
        // weight) can leave sum_w below rnd_sum_w after the last neighbor,
        // e.g. when the uniform draw is exactly 1. num would then equal deg and
        // in_cols/data would be read past this node's neighbor list. Fall back
        // to the last neighbor that has a positive weight.
        if (num == deg) {
          num = deg - 1;
          while (num > 0 && !(prob[in_row_start + num] > 0)) {
            --num;
          }
        }
      }

      IdType pick = graph.in_cols[in_row_start + num];
      IdType eid = (graph.data? graph.data[in_row_start + num] : in_row_start + num);
      *traces_data_ptr = pick;
      *eids_data_ptr = eid;
      // A restart ends the walk before this transition is kept: the pointers
      // are not advanced, so the -1 fill below overwrites the pick just written.
      if ((restart_prob_size > 1) && (hiprand_uniform(&rng) < restart_prob_data[step_idx])) {
        break;
      } else if ((restart_prob_size == 1) && (hiprand_uniform(&rng) < restart_prob_data[0])) {
        break;
      }
      ++traces_data_ptr; ++eids_data_ptr;
      curr = pick;
    }
    // Pad the rest of the trace after an early stop.
    for (; step_idx < max_num_steps; ++step_idx) {
      *(traces_data_ptr++) = -1;
      *(eids_data_ptr++) = -1;
    }
    idx += BLOCK_SIZE;
  }
}

}  // namespace

/*!
 * \brief Random walk for uniform choice: every step picks a neighbor of the
 * current node in the step's edge type uniformly at random.
 *
 * \param hg           Heterograph to walk on; one CSR matrix per edge type is used.
 * \param seeds        1-D array of start nodes. The kernel reads it directly, so it
 *                     must be device memory; the outputs are allocated on its context.
 * \param metapath     1-D array of edge-type ids, one per step. It is dereferenced on
 *                     the host (metapath->data[0]) and copied to the device here.
 * \param restart_prob 1-D device array of restart probabilities: empty (no restart),
 *                     a single value used at every step, or one value per step.
 * \return (traces, eids): traces has shape (num_seeds, len(metapath) + 1) and starts
 *         with the seed; eids has shape (num_seeds, len(metapath)). Slots after an
 *         early stop (dead end or restart) are -1.
 */
template<DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkUniform(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    FloatArray restart_prob) {
  const int64_t max_num_steps = metapath->shape[0];
  const IdType *metapath_data = static_cast<IdType *>(metapath->data);
  const int64_t begin_ntype = hg->meta_graph()->FindEdge(metapath_data[0]).first;
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
  int64_t num_etypes = hg->NumEdgeTypes();
  auto ctx = seeds->ctx;

  const IdType *seed_data = static_cast<const IdType*>(seeds->data);
  CHECK(seeds->ndim == 1) << "seeds shape is not one dimension.";
  const int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);
  IdType *traces_data = traces.Ptr<IdType>();
  IdType *eids_data = eids.Ptr<IdType>();

  // Collect the CSR array pointers of every edge type for the kernel.
  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
    h_graphs[etype].in_ptr  = static_cast<const IdType*>(csr.indptr->data);
    h_graphs[etype].in_cols = static_cast<const IdType*>(csr.indices->data);
    h_graphs[etype].data = (CSRHasData(csr) ? static_cast<const IdType*>(csr.data->data) : nullptr);
  }
  // use cuda stream from local thread
  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto device = DeviceAPI::Get(ctx);
  auto d_graphs = static_cast<GraphKernelData<IdType>*>(
      device->AllocWorkspace(ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
  // copy graph metadata pointers to GPU
  device->CopyDataFromTo(h_graphs.data(), 0, d_graphs, 0,
      (num_etypes) * sizeof(GraphKernelData<IdType>),
      DGLContext{kDLCPU, 0},
      ctx,
      hg->GetCSRMatrix(0).indptr->dtype);
  // copy metapath to GPU
  auto d_metapath = metapath.CopyTo(ctx);
  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);

  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 4;
  // The kernel asserts blockDim.x == BLOCK_SIZE, so derive the launch from it.
  dim3 block(BLOCK_SIZE);
  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
  ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, FloatType, "random walk GPU kernel", {
    CHECK(restart_prob->ctx.device_type == kDLROCM) << "restart prob should be in GPU.";
    CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
    const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
    const int64_t restart_prob_size = restart_prob->shape[0];
    // With more than one entry the kernel reads restart_prob_data[step] at every
    // step, so a shorter array would be read out of bounds.
    CHECK(restart_prob_size <= 1 || restart_prob_size >= max_num_steps)
        << "restart prob should have one entry per step of the metapath.";
    CUDA_KERNEL_CALL(
      (_RandomWalkKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>),
      grid, block, 0, stream,
      random_seed,
      seed_data,
      num_seeds,
      d_metapath_data,
      max_num_steps,
      d_graphs,
      restart_prob_data,
      restart_prob_size,
      max_nodes,
      traces_data,
      eids_data);
  });

  device->FreeWorkspace(ctx, d_graphs);
  return std::make_pair(traces, eids);
}

/** 
 * \brief Random walk for biased choice. We use inverse transform sampling to
 * choose the next step.
 */
template <DLDeviceType XPU, typename FloatType, typename IdType>
std::pair<IdArray, IdArray> RandomWalkBiased(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob) {
  const int64_t max_num_steps = metapath->shape[0];
  const IdType *metapath_data = static_cast<IdType *>(metapath->data);
  const int64_t begin_ntype = hg->meta_graph()->FindEdge(metapath_data[0]).first;
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
  int64_t num_etypes = hg->NumEdgeTypes();
  auto ctx = seeds->ctx;

  const IdType *seed_data = static_cast<const IdType*>(seeds->data);
  CHECK(seeds->ndim == 1) << "seeds shape is not one dimension.";
  const int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);
  IdType *traces_data = traces.Ptr<IdType>();
  IdType *eids_data = eids.Ptr<IdType>();

  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto device = DeviceAPI::Get(ctx);
  // new probs and prob sums pointers
  assert(num_etypes == static_cast<int64_t>(prob.size()));
  std::unique_ptr<FloatType *[]> probs(new FloatType *[prob.size()]);
  std::unique_ptr<FloatType *[]> prob_sums(new FloatType *[prob.size()]);
  std::vector<FloatArray> prob_sums_arr;
  prob_sums_arr.reserve(prob.size());

  // graphs
  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
    h_graphs[etype].in_ptr  = static_cast<const IdType*>(csr.indptr->data);
    h_graphs[etype].in_cols = static_cast<const IdType*>(csr.indices->data);
    h_graphs[etype].data = (CSRHasData(csr) ? static_cast<const IdType*>(csr.data->data) : nullptr);

    int64_t num_segments = csr.indptr->shape[0] - 1;
    // will handle empty probs in the kernel
    if (IsNullArray(prob[etype])) {
      probs[etype] = nullptr;
      prob_sums[etype] = nullptr;
      continue;
    }
    probs[etype] = prob[etype].Ptr<FloatType>();
    prob_sums_arr.push_back(FloatArray::Empty({num_segments}, prob[etype]->dtype, ctx));
    prob_sums[etype] = prob_sums_arr[etype].Ptr<FloatType>();

    // calculate the sum of the neighbor weights
    const IdType *d_offsets = static_cast<const IdType*>(csr.indptr->data);
    size_t temp_storage_size = 0;
    CUDA_CALL(hipcub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_size,
        probs[etype],
        prob_sums[etype],
        num_segments,
        d_offsets,
        d_offsets + 1, stream));
    void *temp_storage = device->AllocWorkspace(ctx, temp_storage_size);
    CUDA_CALL(hipcub::DeviceSegmentedReduce::Sum(temp_storage, temp_storage_size,
        probs[etype],
        prob_sums[etype],
        num_segments,
        d_offsets,
        d_offsets + 1, stream));
    device->FreeWorkspace(ctx, temp_storage);
  }

  // copy graph metadata pointers to GPU
  auto d_graphs = static_cast<GraphKernelData<IdType>*>(
      device->AllocWorkspace(ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
  device->CopyDataFromTo(h_graphs.data(), 0, d_graphs, 0,
      (num_etypes) * sizeof(GraphKernelData<IdType>),
      DGLContext{kDLCPU, 0},
      ctx,
      hg->GetCSRMatrix(0).indptr->dtype);
  // copy probs pointers to GPU
  const FloatType **probs_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
  device->CopyDataFromTo(probs.get(), 0, probs_dev, 0,
      (num_etypes) * sizeof(FloatType *),
      DGLContext{kDLCPU, 0},
      ctx,
      prob[0]->dtype);
  // copy probs_sum pointers to GPU
  const FloatType **prob_sums_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
  device->CopyDataFromTo(prob_sums.get(), 0, prob_sums_dev, 0,
      (num_etypes) * sizeof(FloatType *),
      DGLContext{kDLCPU, 0},
      ctx,
      prob[0]->dtype);
  // copy metapath to GPU
  auto d_metapath = metapath.CopyTo(ctx);
  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);

  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 4;
  dim3 block(256);
  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
lisj's avatar
lisj committed
353
  CHECK(restart_prob->ctx.device_type == kDLROCM) << "restart prob should be in GPU.";
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
  CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
  const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
  const int64_t restart_prob_size = restart_prob->shape[0];
  CUDA_KERNEL_CALL(
    (_RandomWalkBiasedKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>),
    grid, block, 0, stream,
    random_seed,
    seed_data,
    num_seeds,
    d_metapath_data,
    max_num_steps,
    d_graphs,
    probs_dev,
    prob_sums_dev,
    restart_prob_data,
    restart_prob_size,
    max_nodes,
    traces_data,
    eids_data);

  device->FreeWorkspace(ctx, d_graphs);
  device->FreeWorkspace(ctx, probs_dev);
  device->FreeWorkspace(ctx, prob_sums_dev);
  return std::make_pair(traces, eids);
}

template<DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalk(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob) {

  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }

  auto restart_prob = NDArray::Empty(
      {0}, DLDataType{kDLFloat, 32, 1}, DGLContext{XPU, 0});
  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob);
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
}

template<DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithRestart(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    double restart_prob) {

  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }

  auto device_ctx = seeds->ctx;
  auto restart_prob_array = NDArray::Empty(
      {1}, DLDataType{kDLFloat, 64, 1}, device_ctx);
  auto device = dgl::runtime::DeviceAPI::Get(device_ctx);

  // use cuda stream from local thread
  hipStream_t stream = runtime::getCurrentCUDAStream();
  device->CopyDataFromTo(
      &restart_prob, 0, restart_prob_array.Ptr<double>(), 0,
      sizeof(double),
      DGLContext{kDLCPU, 0}, device_ctx,
      restart_prob_array->dtype);
  device->StreamSync(device_ctx, stream);

  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      ret = RandomWalkBiased<XPU, FloatType, IdType>(
          hg, seeds, metapath, prob, restart_prob_array);
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob_array);
  }
}

template<DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob) {

  bool isUniform = true;
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
      isUniform = false;
      break;
    }
  }

  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob);
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
}

template<DLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
    const IdArray src,
    const IdArray dst,
    const int64_t num_samples_per_node,
    const int64_t k) {
lisj's avatar
lisj committed
483
  CHECK(src->ctx.device_type == kDLROCM) <<
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
    "IdArray needs be on GPU!";
  const IdxType* src_data = src.Ptr<IdxType>();
  const IdxType* dst_data = dst.Ptr<IdxType>();
  const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);
  auto ctx = src->ctx;
  // use cuda stream from local thread
  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto frequency_hashmap = FrequencyHashmap<IdxType>(num_dst_nodes,
      num_samples_per_node, ctx, stream);
  auto ret = frequency_hashmap.Topk(src_data, dst_data, src->dtype,
      src->shape[0], num_samples_per_node, k);
  return ret;
}

template
lisj's avatar
lisj committed
499
std::pair<IdArray, IdArray> RandomWalk<kDLROCM, int32_t>(
500
501
502
503
504
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob);
template
lisj's avatar
lisj committed
505
std::pair<IdArray, IdArray> RandomWalk<kDLROCM, int64_t>(
506
507
508
509
510
511
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob);

template
lisj's avatar
lisj committed
512
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLROCM, int32_t>(
513
514
515
516
517
518
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    double restart_prob);
template
lisj's avatar
lisj committed
519
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLROCM, int64_t>(
520
521
522
523
524
525
526
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    double restart_prob);

template
lisj's avatar
lisj committed
527
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLROCM, int32_t>(
528
529
530
531
532
533
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob);
template
lisj's avatar
lisj committed
534
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLROCM, int64_t>(
535
536
537
538
539
540
541
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob);

template
lisj's avatar
lisj committed
542
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLROCM, int32_t>(
543
544
545
546
547
    const IdArray src,
    const IdArray dst,
    const int64_t num_samples_per_node,
    const int64_t k);
template
lisj's avatar
lisj committed
548
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLROCM, int64_t>(
549
550
551
552
553
554
555
556
557
558
559
    const IdArray src,
    const IdArray dst,
    const int64_t num_samples_per_node,
    const int64_t k);


};  // namespace impl

};  // namespace sampling

};  // namespace dgl