"torchvision/csrc/io/decoder/sync_decoder.h" did not exist on "32e16805a17401f5ef5ec825c808d645f5c26509"
randomwalk_with_restart_cpu.cc 2.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/sampling/randomwalk_with_restart_cpu.cc
 * \brief DGL sampler - CPU implementation of metapath-based random walk with restart with OpenMP
 */

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/random.h>
#include <utility>
#include <vector>
#include "randomwalks_impl.h"
#include "randomwalks_cpu.h"
#include "metapath_randomwalk.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

/*!
 * \brief Metapath-guided random walk that may stop with a fixed probability.
 *
 * Every time the termination predicate is evaluated, a fresh uniform draw is
 * taken from the thread-local random engine; the predicate reports "stop"
 * when that draw is below \p restart_prob.  The current trace, node and step
 * index are ignored.
 *
 * \param hg Graph to walk on, forwarded unchanged to MetapathBasedRandomWalk.
 * \param seeds Seed node IDs, forwarded unchanged to MetapathBasedRandomWalk.
 * \param metapath Metapath, forwarded unchanged to MetapathBasedRandomWalk.
 * \param prob Transition weights, forwarded unchanged to MetapathBasedRandomWalk.
 * \param restart_prob Constant stop probability compared against each draw.
 * \return Whatever MetapathBasedRandomWalk returns for these arguments.
 */
template<DLDeviceType XPU, typename IdxType>
IdArray RandomWalkWithRestart(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    double restart_prob) {
  TerminatePredicate<IdxType> stop_here =
    [restart_prob] (IdxType * /* data */, dgl_id_t /* curr */, int64_t /* len */) {
      const double draw = RandomEngine::ThreadLocal()->Uniform<double>();
      return draw < restart_prob;
    };
  return MetapathBasedRandomWalk<XPU, IdxType>(
      hg, seeds, metapath, prob, stop_here);
}

// Explicit CPU instantiations for 32-bit and 64-bit index types.
template IdArray RandomWalkWithRestart<kDLCPU, int32_t>(
    const HeteroGraphPtr, const IdArray, const TypeArray,
    const std::vector<FloatArray> &, double);
template IdArray RandomWalkWithRestart<kDLCPU, int64_t>(
    const HeteroGraphPtr, const IdArray, const TypeArray,
    const std::vector<FloatArray> &, double);

/*!
 * \brief Metapath-guided random walk whose stop probability depends on the
 *        step index.
 *
 * When the termination predicate is evaluated with step index \c len, a fresh
 * uniform draw from the thread-local random engine is compared against
 * \c restart_prob[len]; the predicate reports "stop" when the draw is lower.
 *
 * \param hg Graph to walk on, forwarded unchanged to MetapathBasedRandomWalk.
 * \param seeds Seed node IDs, forwarded unchanged to MetapathBasedRandomWalk.
 * \param metapath Metapath, forwarded unchanged to MetapathBasedRandomWalk.
 * \param prob Transition weights, forwarded unchanged to MetapathBasedRandomWalk.
 * \param restart_prob 1-D floating-point array (element type dispatched by
 *        ATEN_FLOAT_TYPE_SWITCH) holding one stop probability per step; it
 *        must have at least as many entries as \p metapath.
 * \return Whatever MetapathBasedRandomWalk returns for these arguments.
 */
template<DLDeviceType XPU, typename IdxType>
IdArray RandomWalkWithStepwiseRestart(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob) {
  // The predicate below indexes restart_prob_data[len] without a bounds
  // check.  NOTE(review): this assumes the walk only evaluates the predicate
  // with len in [0, metapath->shape[0]), i.e. one restart probability per
  // metapath step -- confirm against MetapathBasedRandomWalk.  Validate the
  // array once here rather than risk reading past the end of its buffer.
  CHECK_EQ(restart_prob->ndim, 1)
    << "restart probability must be a 1-D array";
  CHECK_GE(restart_prob->shape[0], metapath->shape[0])
    << "restart probability must have at least one entry per metapath step";

  IdArray result;

  // The element type (DType) of restart_prob is only known at runtime.
  ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, DType, "restart probability", {
    const DType *restart_prob_data = static_cast<const DType *>(restart_prob->data);
    TerminatePredicate<IdxType> terminate =
      [restart_prob_data] (IdxType * /* data */, dgl_id_t /* curr */, int64_t len) {
        return RandomEngine::ThreadLocal()->Uniform<DType>() < restart_prob_data[len];
      };
    result = MetapathBasedRandomWalk<XPU, IdxType>(hg, seeds, metapath, prob, terminate);
  });

  return result;
}

// Explicit CPU instantiations for 32-bit and 64-bit index types.
template IdArray RandomWalkWithStepwiseRestart<kDLCPU, int32_t>(
    const HeteroGraphPtr, const IdArray, const TypeArray,
    const std::vector<FloatArray> &, FloatArray);
template IdArray RandomWalkWithStepwiseRestart<kDLCPU, int64_t>(
    const HeteroGraphPtr, const IdArray, const TypeArray,
    const std::vector<FloatArray> &, FloatArray);

};  // namespace impl

};  // namespace sampling

};  // namespace dgl