/*! * Copyright (c) 2018 by Contributors * \file graph/sampling/randomwalks_impl.h * \brief DGL sampler - templated implementation definition of random walks */ #ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_ #define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_ #include #include #include #include #include namespace dgl { using namespace dgl::runtime; using namespace dgl::aten; namespace sampling { namespace impl { /*! * \brief Random walk step function */ template using StepFunc = std::function< // ID terminate? std::pair( IdxType *, // node IDs generated so far dgl_id_t, // last node ID int64_t)>; // # of steps /*! * \brief Get the node types traversed by the metapath. * \return A 1D array of shape (len(metapath) + 1,) with node type IDs. */ template TypeArray GetNodeTypesFromMetapath( const HeteroGraphPtr hg, const TypeArray metapath); /*! * \brief Metapath-based random walk. * \param hg The heterograph. * \param seeds A 1D array of seed nodes, with the type the source type of the first * edge type in the metapath. * \param metapath A 1D array of edge types representing the metapath. * \param prob A vector of 1D float arrays, indicating the transition probability of * each edge by edge type. An empty float array assumes uniform transition. * \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs. The * paths that terminated early are padded with -1. * \note This function should be called together with GetNodeTypesFromMetapath to * determine the node type of each node in the random walk traces. */ template IdArray RandomWalk( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath, const std::vector &prob); /*! * \brief Metapath-based random walk with restart probability. * \param hg The heterograph. * \param seeds A 1D array of seed nodes, with the type the source type of the first * edge type in the metapath. * \param metapath A 1D array of edge types representing the metapath. * \param prob A vector of 1D float arrays, indicating the transition probability of * each edge by edge type. An empty float array assumes uniform transition. * \param restart_prob Restart probability * \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs. The * paths that terminated early are padded with -1. * \note This function should be called together with GetNodeTypesFromMetapath to * determine the node type of each node in the random walk traces. */ template IdArray RandomWalkWithRestart( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath, const std::vector &prob, double restart_prob); /*! * \brief Metapath-based random walk with stepwise restart probability. Useful * for PinSAGE-like models. * \param hg The heterograph. * \param seeds A 1D array of seed nodes, with the type the source type of the first * edge type in the metapath. * \param metapath A 1D array of edge types representing the metapath. * \param prob A vector of 1D float arrays, indicating the transition probability of * each edge by edge type. An empty float array assumes uniform transition. * \param restart_prob Restart probability array which has the same number of elements * as \c metapath, indicating the probability to terminate after transition. * \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs. The * paths that terminated early are padded with -1. * \note This function should be called together with GetNodeTypesFromMetapath to * determine the node type of each node in the random walk traces. */ template IdArray RandomWalkWithStepwiseRestart( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath, const std::vector &prob, FloatArray restart_prob); }; // namespace impl }; // namespace sampling }; // namespace dgl #endif // DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_