/**
 *  Copyright (c) 2018 by Contributors
 * @file graph/sampling/randomwalks/randomwalks_impl.h
 * @brief DGL sampler - templated implementation definition of random walks
 */

#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_

#include <dgl/array.h>
#include <dgl/base_heterograph.h>

#include <cstdint>
#include <functional>
#include <tuple>
#include <utility>
#include <vector>

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

27
/**
28
 * @brief Random walk step function
29
 */
30
template <typename IdxType>
31
using StepFunc = std::function<
32
33
34
35
36
    //        ID        Edge ID   terminate?
    std::tuple<dgl_id_t, dgl_id_t, bool>(
        IdxType *,  // node IDs generated so far
        dgl_id_t,   // last node ID
        int64_t)>;  // # of steps
37

38
/**
39
40
 * @brief Get the node types traversed by the metapath.
 * @return A 1D array of shape (len(metapath) + 1,) with node type IDs.
41
 */
42
template <DGLDeviceType XPU, typename IdxType>
43
TypeArray GetNodeTypesFromMetapath(
44
    const HeteroGraphPtr hg, const TypeArray metapath);
45

46
/**
47
48
49
 * @brief Metapath-based random walk.
 * @param hg The heterograph.
 * @param seeds A 1D array of seed nodes, with the type the source type of the
50
51
52
53
 * first edge type in the metapath.
 * @param metapath A 1D array of edge types
 * representing the metapath.
 * @param prob A vector of 1D float arrays,
54
 * indicating the transition probability of each edge by edge type.  An empty
55
56
 * float array assumes uniform transition.
 * @return A 2D array of shape
57
58
59
60
 * (len(seeds), len(metapath) + 1) with node IDs.  The paths that terminated
 * early are padded with -1. A 2D array of shape (len(seeds), len(metapath))
 * with edge IDs.  The paths that terminated early are padded with -1. \note
 * This function should be called together with GetNodeTypesFromMetapath to
61
62
 *       determine the node type of each node in the random walk traces.
 */
63
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
64
std::pair<IdArray, IdArray> RandomWalk(
65
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
66
67
    const std::vector<FloatArray> &prob);

68
/**
69
70
71
 * @brief Metapath-based random walk with restart probability.
 * @param hg The heterograph.
 * @param seeds A 1D array of seed nodes, with the type the source type of the
72
73
74
75
 * first edge type in the metapath.
 * @param metapath A 1D array of edge types
 * representing the metapath.
 * @param prob A vector of 1D float arrays,
76
 * indicating the transition probability of each edge by edge type.  An empty
77
78
79
80
 * float array assumes uniform transition.
 * @param restart_prob Restart
 * probability
 * @return A 2D array of shape (len(seeds), len(metapath) + 1) with
81
82
83
84
85
 * node IDs.  The paths that terminated early are padded with -1. A 2D array of
 * shape (len(seeds), len(metapath)) with edge IDs.  The paths that terminated
 * early are padded with -1. \note This function should be called together with
 * GetNodeTypesFromMetapath to determine the node type of each node in the
 * random walk traces.
86
 */
87
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
88
std::pair<IdArray, IdArray> RandomWalkWithRestart(
89
90
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob);
91

92
/**
93
 * @brief Metapath-based random walk with stepwise restart probability.  Useful
94
 *        for PinSAGE-like models.
95
96
 * @param hg The heterograph.
 * @param seeds A 1D array of seed nodes, with the type the source type of the
97
98
99
100
 * first edge type in the metapath.
 * @param metapath A 1D array of edge types
 * representing the metapath.
 * @param prob A vector of 1D float arrays,
101
 * indicating the transition probability of each edge by edge type.  An empty
102
103
 * float array assumes uniform transition.
 * @param restart_prob Restart
104
 * probability array which has the same number of elements as \c metapath,
105
106
 * indicating the probability to terminate after transition.
 * @return A 2D array
107
108
109
110
111
112
 * of shape (len(seeds), len(metapath) + 1) with node IDs.  The paths that
 * terminated early are padded with -1. A 2D array of shape (len(seeds),
 * len(metapath)) with edge IDs.  The paths that terminated early are padded
 * with -1. \note This function should be called together with
 * GetNodeTypesFromMetapath to determine the node type of each node in the
 * random walk traces.
113
 */
114
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
115
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
116
117
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob);
118

119
template <DGLDeviceType XPU, typename IdxType>
120
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
121
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
122
123
    const int64_t k);

124
125
126
127
128
129
};  // namespace impl

};  // namespace sampling

};  // namespace dgl

130
#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_