randomwalk_gpu.hip 21.8 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
2
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
3
/**
4
 *  Copyright (c) 2021-2022 by Contributors
5
6
 * @file graph/sampling/randomwalk_gpu.cu
 * @brief CUDA random walk sampleing
7
8
 */

sangwzh's avatar
sangwzh committed
9
#include <hiprand/hiprand_kernel.h>
10
11
12
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/random.h>
13
14
#include <dgl/runtime/device_api.h>

sangwzh's avatar
sangwzh committed
15
#include <hipcub/hipcub.hpp>
16
#include <tuple>
17
18
#include <utility>
#include <vector>
19

20
#include "../../../runtime/cuda/cuda_common.h"
21
22
23
24
25
26
27
28
29
30
31
32
33
#include "frequency_hashmap.cuh"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

namespace {

34
template <typename IdType>
35
36
37
38
39
struct GraphKernelData {
  const IdType *in_ptr;
  const IdType *in_cols;
  const IdType *data;
};
sangwzh's avatar
sangwzh committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
template<typename IdType>
inline IdType* __GetDevicePointer(runtime::NDArray array) {
  IdType* ptr = array.Ptr<IdType>();
  if (array.IsPinned()) {
    CUDA_CALL(hipHostGetDevicePointer(&ptr, ptr, 0));
  }
  return ptr;
}

inline void* __GetDevicePointer(runtime::NDArray array) {
  void* ptr = array->data;
  if (array.IsPinned()) {
    CUDA_CALL(hipHostGetDevicePointer(&ptr, ptr, 0));
  }
  return ptr;
}
56

57
template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
58
59
__global__ void _RandomWalkKernel(
    const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
60
61
62
63
    const IdType *metapath_data, const uint64_t max_num_steps,
    const GraphKernelData<IdType> *graphs, const FloatType *restart_prob_data,
    const int64_t restart_prob_size, const int64_t max_nodes,
    IdType *out_traces_data, IdType *out_eids_data) {
64
65
  assert(BLOCK_SIZE == blockDim.x);
  int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;
66
67
  int64_t last_idx =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);
68
  int64_t trace_length = (max_num_steps + 1);
sangwzh's avatar
sangwzh committed
69
  hiprandState_t rng;
70
  // reference:
sangwzh's avatar
sangwzh committed
71
72
  //     https://docs.nvidia.com/cuda/hiprand/device-api-overview.html#performance-notes
  hiprand_init(rand_seed + idx, 0, 0, &rng);
73
74
75

  while (idx < last_idx) {
    IdType curr = seed_data[idx];
76
    assert(curr < max_nodes);
77
78
79
80
81
82
83
84
85
86
87
88
    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
    *(traces_data_ptr++) = curr;
    int64_t step_idx;
    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
      IdType metapath_id = metapath_data[step_idx];
      const GraphKernelData<IdType> &graph = graphs[metapath_id];
      const int64_t in_row_start = graph.in_ptr[curr];
      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
      if (deg == 0) {  // the degree is zero
        break;
      }
sangwzh's avatar
sangwzh committed
89
      const int64_t num = hiprand(&rng) % deg;
90
      IdType pick = graph.in_cols[in_row_start + num];
91
92
      IdType eid =
          (graph.data ? graph.data[in_row_start + num] : in_row_start + num);
93
94
      *traces_data_ptr = pick;
      *eids_data_ptr = eid;
95
      if ((restart_prob_size > 1) &&
sangwzh's avatar
sangwzh committed
96
          (hiprand_uniform(&rng) < restart_prob_data[step_idx])) {
97
        break;
98
99
      } else if (
          (restart_prob_size == 1) &&
sangwzh's avatar
sangwzh committed
100
          (hiprand_uniform(&rng) < restart_prob_data[0])) {
101
102
        break;
      }
103
104
      ++traces_data_ptr;
      ++eids_data_ptr;
105
106
107
108
109
110
111
112
113
114
      curr = pick;
    }
    for (; step_idx < max_num_steps; ++step_idx) {
      *(traces_data_ptr++) = -1;
      *(eids_data_ptr++) = -1;
    }
    idx += BLOCK_SIZE;
  }
}

115
116
template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkBiasedKernel(
117
118
119
120
121
122
    const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
    const IdType *metapath_data, const uint64_t max_num_steps,
    const GraphKernelData<IdType> *graphs, const FloatType **probs,
    const FloatType **prob_sums, const FloatType *restart_prob_data,
    const int64_t restart_prob_size, const int64_t max_nodes,
    IdType *out_traces_data, IdType *out_eids_data) {
123
124
  assert(BLOCK_SIZE == blockDim.x);
  int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;
125
126
  int64_t last_idx =
      min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);
127
  int64_t trace_length = (max_num_steps + 1);
sangwzh's avatar
sangwzh committed
128
  hiprandState_t rng;
129
  // reference:
sangwzh's avatar
sangwzh committed
130
131
  //     https://docs.nvidia.com/cuda/hiprand/device-api-overview.html#performance-notes
  hiprand_init(rand_seed + idx, 0, 0, &rng);
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

  while (idx < last_idx) {
    IdType curr = seed_data[idx];
    assert(curr < max_nodes);
    IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
    IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
    *(traces_data_ptr++) = curr;
    int64_t step_idx;
    for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
      IdType metapath_id = metapath_data[step_idx];
      const GraphKernelData<IdType> &graph = graphs[metapath_id];
      const int64_t in_row_start = graph.in_ptr[curr];
      const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
      if (deg == 0) {  // the degree is zero
        break;
      }

      // randomly select by weight
      const FloatType *prob_sum = prob_sums[metapath_id];
      const FloatType *prob = probs[metapath_id];
      int64_t num;
      if (prob == nullptr) {
sangwzh's avatar
sangwzh committed
154
        num = hiprand(&rng) % deg;
155
      } else {
sangwzh's avatar
sangwzh committed
156
        auto rnd_sum_w = prob_sum[curr] * hiprand_uniform(&rng);
157
158
159
160
161
162
163
164
        FloatType sum_w{0.};
        for (num = 0; num < deg; ++num) {
          sum_w += prob[in_row_start + num];
          if (sum_w >= rnd_sum_w) break;
        }
      }

      IdType pick = graph.in_cols[in_row_start + num];
165
166
      IdType eid =
          (graph.data ? graph.data[in_row_start + num] : in_row_start + num);
167
168
      *traces_data_ptr = pick;
      *eids_data_ptr = eid;
169
      if ((restart_prob_size > 1) &&
sangwzh's avatar
sangwzh committed
170
          (hiprand_uniform(&rng) < restart_prob_data[step_idx])) {
171
        break;
172
173
      } else if (
          (restart_prob_size == 1) &&
sangwzh's avatar
sangwzh committed
174
          (hiprand_uniform(&rng) < restart_prob_data[0])) {
175
176
        break;
      }
177
178
      ++traces_data_ptr;
      ++eids_data_ptr;
179
180
181
182
183
184
185
186
187
188
      curr = pick;
    }
    for (; step_idx < max_num_steps; ++step_idx) {
      *(traces_data_ptr++) = -1;
      *(eids_data_ptr++) = -1;
    }
    idx += BLOCK_SIZE;
  }
}

189
190
191
}  // namespace

// random walk for uniform choice
192
template <DGLDeviceType XPU, typename IdType>
193
std::pair<IdArray, IdArray> RandomWalkUniform(
194
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
195
196
    FloatArray restart_prob) {
  const int64_t max_num_steps = metapath->shape[0];
sangwzh's avatar
sangwzh committed
197
198
  // const IdType *metapath_data = static_cast<IdType *>(metapath->data);
  const IdType *metapath_data = static_cast<const IdType *>(__GetDevicePointer(metapath));
199
200
  const int64_t begin_ntype =
      hg->meta_graph()->FindEdge(metapath_data[0]).first;
201
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
202
  int64_t num_etypes = hg->NumEdgeTypes();
203
  auto ctx = seeds->ctx;
204

sangwzh's avatar
sangwzh committed
205
206
207
  // const IdType *seed_data = static_cast<const IdType *>(seeds->data);
  const IdType *seed_data = static_cast<const IdType *>(__GetDevicePointer(seeds));
  // const IdType *seed_data = static_cast<const IdType *>(__GetDevicePointer(seeds));
208
209
210
  CHECK(seeds->ndim == 1) << "seeds shape is not one dimension.";
  const int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
211
212
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);
213
214
215
  IdType *traces_data = traces.Ptr<IdType>();
  IdType *eids_data = eids.Ptr<IdType>();

216
  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
217
218
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
sangwzh's avatar
sangwzh committed
219
220
221
222
223
224
225
    // h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);
    // h_graphs[etype].in_cols = static_cast<const IdType *>(csr.indices->data);
    // h_graphs[etype].data =
    //     (CSRHasData(csr) ? static_cast<const IdType *>(csr.data->data)
    //                      : nullptr);
    h_graphs[etype].in_ptr = static_cast<const IdType *>(__GetDevicePointer(csr.indptr));
    h_graphs[etype].in_cols = static_cast<const IdType *>(__GetDevicePointer(csr.indices));
226
    h_graphs[etype].data =
sangwzh's avatar
sangwzh committed
227
        (CSRHasData(csr) ? static_cast<const IdType *>(__GetDevicePointer(csr.data))
228
                         : nullptr);
229
  }
230
  // use cuda stream from local thread
sangwzh's avatar
sangwzh committed
231
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
232
  auto device = DeviceAPI::Get(ctx);
233
234
  auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(
      ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
235
  // copy graph metadata pointers to GPU
236
237
238
239
  device->CopyDataFromTo(
      h_graphs.data(), 0, d_graphs, 0,
      (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},
      ctx, hg->GetCSRMatrix(0).indptr->dtype);
240
241
242
  // copy metapath to GPU
  auto d_metapath = metapath.CopyTo(ctx);
  const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);
243
244
245
246
247
248

  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 4;
  dim3 block(256);
  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
249
250
  ATEN_FLOAT_TYPE_SWITCH(
      restart_prob->dtype, FloatType, "random walk GPU kernel", {
sangwzh's avatar
sangwzh committed
251
        CHECK(restart_prob->ctx.device_type == kDGLCUDA||restart_prob->ctx.device_type == kDGLROCM)
252
253
            << "restart prob should be in GPU.";
        CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
sangwzh's avatar
sangwzh committed
254
255
        // const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
        const FloatType *restart_prob_data = static_cast<const FloatType *>(__GetDevicePointer(restart_prob));
256
257
258
259
260
261
262
        const int64_t restart_prob_size = restart_prob->shape[0];
        CUDA_KERNEL_CALL(
            (_RandomWalkKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,
            block, 0, stream, random_seed, seed_data, num_seeds,
            d_metapath_data, max_num_steps, d_graphs, restart_prob_data,
            restart_prob_size, max_nodes, traces_data, eids_data);
      });
263
264
265
266
267

  device->FreeWorkspace(ctx, d_graphs);
  return std::make_pair(traces, eids);
}

268
/**
269
 * @brief Random walk for biased choice. We use inverse transform sampling to
270
271
 * choose the next step.
 */
272
template <DGLDeviceType XPU, typename FloatType, typename IdType>
273
std::pair<IdArray, IdArray> RandomWalkBiased(
274
275
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob) {
276
  const int64_t max_num_steps = metapath->shape[0];
sangwzh's avatar
sangwzh committed
277
278
  // const IdType *metapath_data = static_cast<IdType *>(metapath->data);
  const IdType *metapath_data = static_cast<IdType *>(__GetDevicePointer(metapath));
279
280
  const int64_t begin_ntype =
      hg->meta_graph()->FindEdge(metapath_data[0]).first;
281
282
283
284
  const int64_t max_nodes = hg->NumVertices(begin_ntype);
  int64_t num_etypes = hg->NumEdgeTypes();
  auto ctx = seeds->ctx;

sangwzh's avatar
sangwzh committed
285
286
  // const IdType *seed_data = static_cast<const IdType *>(seeds->data);
  const IdType *seed_data = static_cast<const IdType *>(__GetDevicePointer(seeds));
287
288
289
290
291
292
  CHECK(seeds->ndim == 1) << "seeds shape is not one dimension.";
  const int64_t num_seeds = seeds->shape[0];
  int64_t trace_length = max_num_steps + 1;
  IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, ctx);
  IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, ctx);
  IdType *traces_data = traces.Ptr<IdType>();
sangwzh's avatar
sangwzh committed
293
294
295
  // IdType *traces_data = static_cast<IdType *>(__GetDevicePointer(traces));
  // IdType *eids_data = eids.Ptr<IdType>();
  IdType *eids_data = static_cast<IdType *>(__GetDevicePointer(eids));
296

sangwzh's avatar
sangwzh committed
297
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
298
299
300
301
302
303
304
305
306
307
308
309
  auto device = DeviceAPI::Get(ctx);
  // new probs and prob sums pointers
  assert(num_etypes == static_cast<int64_t>(prob.size()));
  std::unique_ptr<FloatType *[]> probs(new FloatType *[prob.size()]);
  std::unique_ptr<FloatType *[]> prob_sums(new FloatType *[prob.size()]);
  std::vector<FloatArray> prob_sums_arr;
  prob_sums_arr.reserve(prob.size());

  // graphs
  std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const CSRMatrix &csr = hg->GetCSRMatrix(etype);
sangwzh's avatar
sangwzh committed
310
311
312
313
314
315
316
    // h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);
    // h_graphs[etype].in_cols = static_cast<const IdType *>(csr.indices->data);
    // h_graphs[etype].data =
    //     (CSRHasData(csr) ? static_cast<const IdType *>(csr.data->data)
    //                      : nullptr);
    h_graphs[etype].in_ptr = static_cast<const IdType *>(__GetDevicePointer(csr.indptr));
    h_graphs[etype].in_cols = static_cast<const IdType *>(__GetDevicePointer(csr.indices));
317
    h_graphs[etype].data =
sangwzh's avatar
sangwzh committed
318
        (CSRHasData(csr) ? static_cast<const IdType *>(__GetDevicePointer(csr.data))
319
                         : nullptr);
320
321
322
323
324
325
326
327

    int64_t num_segments = csr.indptr->shape[0] - 1;
    // will handle empty probs in the kernel
    if (IsNullArray(prob[etype])) {
      probs[etype] = nullptr;
      prob_sums[etype] = nullptr;
      continue;
    }
sangwzh's avatar
sangwzh committed
328
329
    // probs[etype] = prob[etype].Ptr<FloatType>();
    probs[etype] = static_cast<FloatType *>(__GetDevicePointer(prob[etype]));
330
331
    prob_sums_arr.push_back(
        FloatArray::Empty({num_segments}, prob[etype]->dtype, ctx));
sangwzh's avatar
sangwzh committed
332
333
    // prob_sums[etype] = prob_sums_arr[etype].Ptr<FloatType>();
    prob_sums[etype] = static_cast<FloatType *>(__GetDevicePointer(prob_sums_arr[etype]));
334
335

    // calculate the sum of the neighbor weights
sangwzh's avatar
sangwzh committed
336
337
    // const IdType *d_offsets = static_cast<const IdType *>(csr.indptr->data);
    const IdType *d_offsets = static_cast<const IdType *>(__GetDevicePointer(csr.indptr));
338
    size_t temp_storage_size = 0;
sangwzh's avatar
sangwzh committed
339
    CUDA_CALL(hipcub::DeviceSegmentedReduce::Sum(
340
341
        nullptr, temp_storage_size, probs[etype], prob_sums[etype],
        num_segments, d_offsets, d_offsets + 1, stream));
342
    void *temp_storage = device->AllocWorkspace(ctx, temp_storage_size);
sangwzh's avatar
sangwzh committed
343
    CUDA_CALL(hipcub::DeviceSegmentedReduce::Sum(
344
345
        temp_storage, temp_storage_size, probs[etype], prob_sums[etype],
        num_segments, d_offsets, d_offsets + 1, stream));
346
347
348
349
    device->FreeWorkspace(ctx, temp_storage);
  }

  // copy graph metadata pointers to GPU
350
351
352
353
354
355
  auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(
      ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
  device->CopyDataFromTo(
      h_graphs.data(), 0, d_graphs, 0,
      (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},
      ctx, hg->GetCSRMatrix(0).indptr->dtype);
356
357
358
  // copy probs pointers to GPU
  const FloatType **probs_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
359
360
361
  device->CopyDataFromTo(
      probs.get(), 0, probs_dev, 0, (num_etypes) * sizeof(FloatType *),
      DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);
362
363
364
  // copy probs_sum pointers to GPU
  const FloatType **prob_sums_dev = static_cast<const FloatType **>(
      device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
365
366
367
  device->CopyDataFromTo(
      prob_sums.get(), 0, prob_sums_dev, 0, (num_etypes) * sizeof(FloatType *),
      DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);
368
369
  // copy metapath to GPU
  auto d_metapath = metapath.CopyTo(ctx);
sangwzh's avatar
sangwzh committed
370
371
  // const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);
  const IdType *d_metapath_data = static_cast<IdType *>(__GetDevicePointer(d_metapath));
372
373
374
375
376
377

  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 4;
  dim3 block(256);
  dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
  const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
sangwzh's avatar
sangwzh committed
378
  CHECK(restart_prob->ctx.device_type == kDGLCUDA ||restart_prob->ctx.device_type == kDGLROCM)
379
      << "restart prob should be in GPU.";
380
  CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
sangwzh's avatar
sangwzh committed
381
382
383
  // const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
  const FloatType *restart_prob_data = static_cast<const FloatType *>(__GetDevicePointer(restart_prob));
  const int64_t restart_prob_size = restart_prob->shape[0];  
384
  CUDA_KERNEL_CALL(
385
386
387
388
      (_RandomWalkBiasedKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,
      block, 0, stream, random_seed, seed_data, num_seeds, d_metapath_data,
      max_num_steps, d_graphs, probs_dev, prob_sums_dev, restart_prob_data,
      restart_prob_size, max_nodes, traces_data, eids_data);
389
390
391
392
393
394
395

  device->FreeWorkspace(ctx, d_graphs);
  device->FreeWorkspace(ctx, probs_dev);
  device->FreeWorkspace(ctx, prob_sums_dev);
  return std::make_pair(traces, eids);
}

396
template <DGLDeviceType XPU, typename IdType>
397
std::pair<IdArray, IdArray> RandomWalk(
398
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
399
    const std::vector<FloatArray> &prob) {
400
  bool isUniform = true;
401
402
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
403
404
      isUniform = false;
      break;
405
406
407
    }
  }

408
409
  auto restart_prob =
      NDArray::Empty({0}, DGLDataType{kDGLFloat, 32, 1}, DGLContext{XPU, 0});
410
411
412
  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
413
414
      ret = RandomWalkBiased<XPU, FloatType, IdType>(
          hg, seeds, metapath, prob, restart_prob);
415
416
417
418
419
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
420
421
}

422
template <DGLDeviceType XPU, typename IdType>
423
std::pair<IdArray, IdArray> RandomWalkWithRestart(
424
425
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob) {
426
  bool isUniform = true;
427
428
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
429
430
      isUniform = false;
      break;
431
432
    }
  }
433

434
  auto device_ctx = seeds->ctx;
435
436
  auto restart_prob_array =
      NDArray::Empty({1}, DGLDataType{kDGLFloat, 64, 1}, device_ctx);
437
438
  auto device = dgl::runtime::DeviceAPI::Get(device_ctx);

439
  // use cuda stream from local thread
sangwzh's avatar
sangwzh committed
440
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
441
  device->CopyDataFromTo(
442
443
      &restart_prob, 0, restart_prob_array.Ptr<double>(), 0, sizeof(double),
      DGLContext{kDGLCPU, 0}, device_ctx, restart_prob_array->dtype);
444
445
  device->StreamSync(device_ctx, stream);

446
447
448
449
450
451
452
453
  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
      ret = RandomWalkBiased<XPU, FloatType, IdType>(
          hg, seeds, metapath, prob, restart_prob_array);
    });
    return ret;
  } else {
454
455
    return RandomWalkUniform<XPU, IdType>(
        hg, seeds, metapath, restart_prob_array);
456
  }
457
458
}

459
template <DGLDeviceType XPU, typename IdType>
460
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
461
462
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob) {
463
  bool isUniform = true;
464
465
  for (const auto &etype_prob : prob) {
    if (!IsNullArray(etype_prob)) {
466
467
      isUniform = false;
      break;
468
469
470
    }
  }

471
472
473
  if (!isUniform) {
    std::pair<IdArray, IdArray> ret;
    ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
474
475
      ret = RandomWalkBiased<XPU, FloatType, IdType>(
          hg, seeds, metapath, prob, restart_prob);
476
477
478
479
480
    });
    return ret;
  } else {
    return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
  }
481
482
}

483
template <DGLDeviceType XPU, typename IdxType>
484
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
485
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
486
    const int64_t k) {
sangwzh's avatar
sangwzh committed
487
  CHECK(src->ctx.device_type == kDGLCUDA || src->ctx.device_type == kDGLROCM) << "IdArray needs be on GPU!";
sangwzh's avatar
sangwzh committed
488
489
490
491
  // const IdxType *src_data = src.Ptr<IdxType>();
  const IdxType *src_data = static_cast<IdxType*>(__GetDevicePointer(src));
  // const IdxType *dst_data = dst.Ptr<IdxType>();
  const IdxType *dst_data = static_cast<IdxType*>(__GetDevicePointer(dst));
492
493
  const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);
  auto ctx = src->ctx;
494
  // use cuda stream from local thread
sangwzh's avatar
sangwzh committed
495
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
496
497
498
499
  auto frequency_hashmap = FrequencyHashmap<IdxType>(
      num_dst_nodes, num_samples_per_node, ctx, stream);
  auto ret = frequency_hashmap.Topk(
      src_data, dst_data, src->dtype, src->shape[0], num_samples_per_node, k);
500
501
502
  return ret;
}

503
504
template std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
505
    const std::vector<FloatArray> &prob);
506
507
template std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
508
509
    const std::vector<FloatArray> &prob);

510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob);
template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob);

template std::pair<IdArray, IdArray>
RandomWalkWithStepwiseRestart<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob);
template std::pair<IdArray, IdArray>
RandomWalkWithStepwiseRestart<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob);

template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDGLCUDA, int32_t>(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
529
    const int64_t k);
530
531
532
template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDGLCUDA, int64_t>(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
533
534
535
536
537
538
539
    const int64_t k);

};  // namespace impl

};  // namespace sampling

};  // namespace dgl