/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/sampling/randomwalks/randomwalks_impl.h
 * \brief DGL sampler - templated implementation declarations of random walks
 */

#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_

#include <dgl/base_heterograph.h>
#include <dgl/array.h>
#include <vector>
#include <utility>
#include <tuple>
#include <functional>

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

/*!
 * \brief Random walk step function
 */
template<typename IdxType>
using StepFunc = std::function<
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
31
32
  //        ID        Edge ID   terminate?
  std::tuple<dgl_id_t, dgl_id_t, bool>(
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
      IdxType *,    // node IDs generated so far
      dgl_id_t,     // last node ID
      int64_t)>;    // # of steps

/*!
 * \brief Get the node types traversed by the metapath.
 * \return A 1D array of shape (len(metapath) + 1,) with node type IDs.
 */
template<DLDeviceType XPU, typename IdxType>
TypeArray GetNodeTypesFromMetapath(
    const HeteroGraphPtr hg,
    const TypeArray metapath);

/*!
 * \brief Metapath-based random walk.
 * \param hg The heterograph.
 * \param seeds A 1D array of seed nodes, with the type the source type of the first
 *        edge type in the metapath.
 * \param metapath A 1D array of edge types representing the metapath.
 * \param prob A vector of 1D float arrays, indicating the transition probability of
 *        each edge by edge type.  An empty float array assumes uniform transition.
 * \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs.  The
 *         paths that terminated early are padded with -1.
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
56
57
 *         A 2D array of shape (len(seeds), len(metapath)) with edge IDs.  The
 *         paths that terminated early are padded with -1.
58
59
60
61
 * \note This function should be called together with GetNodeTypesFromMetapath to
 *       determine the node type of each node in the random walk traces.
 */
template<DLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
62
std::pair<IdArray, IdArray> RandomWalk(
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob);

/*!
 * \brief Metapath-based random walk with restart probability.
 * \param hg The heterograph.
 * \param seeds A 1D array of seed nodes, with the type the source type of the first
 *        edge type in the metapath.
 * \param metapath A 1D array of edge types representing the metapath.
 * \param prob A vector of 1D float arrays, indicating the transition probability of
 *        each edge by edge type.  An empty float array assumes uniform transition.
 * \param restart_prob Restart probability
 * \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs.  The
 *         paths that terminated early are padded with -1.
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
79
80
 *         A 2D array of shape (len(seeds), len(metapath)) with edge IDs.  The
 *         paths that terminated early are padded with -1.
81
82
83
84
 * \note This function should be called together with GetNodeTypesFromMetapath to
 *       determine the node type of each node in the random walk traces.
 */
template<DLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
85
std::pair<IdArray, IdArray> RandomWalkWithRestart(
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    double restart_prob);

/*!
 * \brief Metapath-based random walk with stepwise restart probability.  Useful
 *        for PinSAGE-like models.
 * \param hg The heterograph.
 * \param seeds A 1D array of seed nodes, with the type the source type of the first
 *        edge type in the metapath.
 * \param metapath A 1D array of edge types representing the metapath.
 * \param prob A vector of 1D float arrays, indicating the transition probability of
 *        each edge by edge type.  An empty float array assumes uniform transition.
 * \param restart_prob Restart probability array which has the same number of elements
 *        as \c metapath, indicating the probability to terminate after transition.
 * \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs.  The
 *         paths that terminated early are padded with -1.
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
105
106
 *         A 2D array of shape (len(seeds), len(metapath)) with edge IDs.  The
 *         paths that terminated early are padded with -1.
107
108
109
110
 * \note This function should be called together with GetNodeTypesFromMetapath to
 *       determine the node type of each node in the random walk traces.
 */
template<DLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
111
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
112
113
114
115
116
117
118
119
120
121
122
123
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob);

};  // namespace impl

};  // namespace sampling

};  // namespace dgl

#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_