/**
 *  Copyright (c) 2021 by Contributors
 * @file graph/sampling/node2vec_impl.h
 * @brief DGL sampler - templated implementation declaration of node2vec random
 * walks.
 */

#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_IMPL_H_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_IMPL_H_

#include <dgl/array.h>
#include <dgl/base_heterograph.h>

#include <cstdint>
#include <functional>
#include <tuple>
#include <utility>
#include <vector>

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

28
/**
29
30
31
 * @brief Node2vec random walk.
 * @param hg The heterograph.
 * @param seeds A 1D array of seed nodes, with the type the source type of the
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
32
 * first edge type in the metapath.
33
 * @param p Float, indicating likelihood of immediately revisiting a node in the
34
35
36
37
 * walk.
 * @param q Float, control parameter to interpolate between breadth-first
 * strategy and depth-first strategy.
 * @param walk_length Int, length of walk.
38
 * @param prob A vector of 1D float arrays, indicating the transition
39
 *        probability of each edge by edge type.  An empty float array assumes
40
41
 * uniform transition.
 * @return A 2D array of shape (len(seeds), len(walk_length)
42
 * + 1) with node IDs.  The paths that terminated early are padded with -1.
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
43
 */
44
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
45
46
std::pair<IdArray, IdArray> Node2vec(
    const HeteroGraphPtr hg, const IdArray seeds, const double p,
47
    const double q, const int64_t walk_length, const FloatArray &prob);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
48
49
50
51
52
53
54
55

};  // namespace impl

};  // namespace sampling

};  // namespace dgl

#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_NODE2VEC_IMPL_H_