/**
 *  Copyright (c) 2018 by Contributors
 * @file graph/sampling/randomwalk_with_restart_cpu.cc
 * @brief DGL sampler - CPU implementation of metapath-based random walk with
 * restart with OpenMP
 */

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/random.h>

#include <cstdint>
#include <utility>
#include <vector>

#include "metapath_randomwalk.h"
#include "randomwalks_cpu.h"
#include "randomwalks_impl.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

28
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
29
std::pair<IdArray, IdArray> RandomWalkWithRestart(
30
31
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob) {
32
  TerminatePredicate<IdxType> terminate =
33
34
35
36
37
      [restart_prob](IdxType *data, dgl_id_t curr, int64_t len) {
        return RandomEngine::ThreadLocal()->Uniform<double>() < restart_prob;
      };
  return MetapathBasedRandomWalk<XPU, IdxType>(
      hg, seeds, metapath, prob, terminate);
38
39
}

40
41
42
43
44
45
46
47
template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCPU, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob);
template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCPU, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob);

template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
48
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
49
50
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob) {
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
51
  std::pair<IdArray, IdArray> result;
52
53
54
55

  ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, DType, "restart probability", {
    DType *restart_prob_data = static_cast<DType *>(restart_prob->data);
    TerminatePredicate<IdxType> terminate =
56
57
58
59
60
61
        [restart_prob_data](IdxType *data, dgl_id_t curr, int64_t len) {
          return RandomEngine::ThreadLocal()->Uniform<DType>() <
                 restart_prob_data[len];
        };
    result = MetapathBasedRandomWalk<XPU, IdxType>(
        hg, seeds, metapath, prob, terminate);
62
63
64
65
66
  });

  return result;
}

67
68
69
70
71
72
73
74
template std::pair<IdArray, IdArray>
RandomWalkWithStepwiseRestart<kDGLCPU, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob);
template std::pair<IdArray, IdArray>
RandomWalkWithStepwiseRestart<kDGLCPU, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob);
75
76
77
78
79
80

};  // namespace impl

};  // namespace sampling

};  // namespace dgl