Unverified Commit 4cf5f682 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Polish metapath_randomwalk.h (#5471)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 8a830272
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* @file graph/sampler/generic_randomwalk_cpu.h * @file graph/sampler/generic_randomwalk_cpu.h
* @brief DGL sampler - templated implementation definition of random walks on * @brief DGL sampler - templated implementation definition of random walks on
* CPU * CPU.
*/ */
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_ #ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_
...@@ -30,21 +30,17 @@ namespace impl { ...@@ -30,21 +30,17 @@ namespace impl {
namespace { namespace {
// bool WhetherToTerminate(
// IdxType *node_ids_generated_so_far,
// dgl_id_t last_node_id_generated,
// int64_t number_of_nodes_generated_so_far)
template <typename IdxType> template <typename IdxType>
using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>; using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>;
/** /**
* @brief Select one successor of metapath-based random walk, given the path * @brief Select one successor of metapath-based random walk, given the path
* generated so far. * generated so far.
* *
* @param data The path generated so far, of type \c IdxType. * @param data The path generated so far, of type \c IdxType.
* @param curr The last node ID generated. * @param curr The last node ID generated.
* @param len The number of nodes generated so far. Note that the seed node is * @param len The number of nodes generated so far. Note that the seed node is
* always included as \c data[0], and the successors start from \c data[1]. * always included as \c data[0], and the successors start from \c data[1].
* *
* @param edges_by_type Vector of results from \c GetAdj() by edge type. * @param edges_by_type Vector of results from \c GetAdj() by edge type.
* @param metapath_data Edge types of given metapath. * @param metapath_data Edge types of given metapath.
...@@ -52,7 +48,7 @@ using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>; ...@@ -52,7 +48,7 @@ using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>;
* @param terminate Predicate for terminating the current random walk path. * @param terminate Predicate for terminating the current random walk path.
* *
* @return A tuple of ID of next successor (-1 if not exist), the last traversed * @return A tuple of ID of next successor (-1 if not exist), the last traversed
* edge ID, as well as whether to terminate. * edge ID, as well as whether to terminate.
*/ */
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
...@@ -78,9 +74,9 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( ...@@ -78,9 +74,9 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
const int64_t size = offsets[curr + 1] - offsets[curr]; const int64_t size = offsets[curr + 1] - offsets[curr];
if (size == 0) return std::make_tuple(-1, -1, true); if (size == 0) return std::make_tuple(-1, -1, true);
// Use a reference to the original array instead of copying // Use a reference to the original array instead of copying. This avoids
// This avoids updating the ref counts atomically from different threads // updating the ref counts atomically from different threads and avoids cache
// and avoids cache ping-ponging in the tight loop // ping-ponging in the tight loop.
const FloatArray &prob_etype = prob[etype]; const FloatArray &prob_etype = prob[etype];
IdxType idx = 0; IdxType idx = 0;
if (IsNullArray(prob_etype)) { if (IsNullArray(prob_etype)) {
...@@ -105,23 +101,22 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( ...@@ -105,23 +101,22 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
/** /**
* @brief Select one successor of metapath-based random walk, given the path * @brief Select one successor of metapath-based random walk, given the path
* generated so far specifically for the uniform probability distribution. * generated so far specifically for the uniform probability distribution.
* *
* @param data The path generated so far, of type \c IdxType. * @param data The path generated so far, of type \c IdxType.
* @param curr The last node ID generated. * @param curr The last node ID generated.
* @param len The number of nodes generated so far. Note that the seed node is * @param len The number of nodes generated so far. Note that the seed node is
* always included as \c data[0], and the successors start from \c data[1]. * always included as \c data[0], and the successors start from \c data[1].
* *
* @param edges_by_type Vector of results from \c GetAdj() by edge type. * @param edges_by_type Vector of results from \c GetAdj() by edge type.
* @param metapath_data Edge types of given metapath. * @param metapath_data Edge types of given metapath.
* @param prob Transition probability per edge type, for this special case this * @param prob Transition probability per edge type, for this special case this
* will be a NullArray * will be a NullArray.
* @param terminate Predicate for terminating the current * @param terminate Predicate for terminating the current random walk path.
* random walk path.
* *
* @return A pair of ID of next successor (-1 if not exist), as well as whether * @return A pair of ID of next successor (-1 if not exist), as well as whether
* to terminate. \note This function is called only if all the probability * to terminate. \note This function is called only if all the probability
* arrays are null. * arrays are null.
*/ */
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform( std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
...@@ -159,14 +154,14 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform( ...@@ -159,14 +154,14 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
* @brief Metapath-based random walk. * @brief Metapath-based random walk.
* @param hg The heterograph. * @param hg The heterograph.
* @param seeds A 1D array of seed nodes, with the type the source type of the * @param seeds A 1D array of seed nodes, with the type the source type of the
* first edge type in the metapath. * first edge type in the metapath.
* @param metapath A 1D array of edge types representing the metapath. * @param metapath A 1D array of edge types representing the metapath.
* @param prob A vector of 1D float arrays, indicating the transition * @param prob A vector of 1D float arrays, indicating the transition
* probability of each edge by edge type. An empty float array assumes uniform * probability of each edge by edge type. An empty float array assumes
* transition. * uniform transition.
* @param terminate Predicate for terminating a random walk path. * @param terminate Predicate for terminating a random walk path.
* @return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs, * @return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs,
* and A 2D array of shape (len(seeds), len(metapath)) with edge IDs. * and A 2D array of shape (len(seeds), len(metapath)) with edge IDs.
*/ */
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> MetapathBasedRandomWalk( std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment