Unverified commit 4cf5f682, authored by Hongzhi (Steve), Chen; committed by GitHub
Browse files

[Misc] Polish metapath_randomwalk.h (#5471)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 8a830272
......@@ -2,7 +2,7 @@
* Copyright (c) 2018 by Contributors
* @file graph/sampling/randomwalks/metapath_randomwalk.h
* @brief DGL sampler - templated implementation definition of random walks on
* CPU
* CPU.
*/
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_
......@@ -30,21 +30,17 @@ namespace impl {
namespace {
// Callable type of the predicate that the random walk routines in this file
// take as their \c terminate argument ("Predicate for terminating the current
// random walk path").  A conforming callable has the shape:
//
//   bool WhetherToTerminate(
//       IdxType *node_ids_generated_so_far,
//       dgl_id_t last_node_id_generated,
//       int64_t number_of_nodes_generated_so_far)
//
// The three arguments correspond to the \c data, \c curr and \c len values
// documented on the step functions.
template <typename IdxType>
using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>;
/**
* @brief Select one successor of metapath-based random walk, given the path
* generated so far.
* generated so far.
*
* @param data The path generated so far, of type \c IdxType.
* @param curr The last node ID generated.
* @param len The number of nodes generated so far. Note that the seed node is
* always included as \c data[0], and the successors start from \c data[1].
* always included as \c data[0], and the successors start from \c data[1].
*
* @param edges_by_type Vector of results from \c GetAdj() by edge type.
* @param metapath_data Edge types of given metapath.
......@@ -52,7 +48,7 @@ using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>;
* @param terminate Predicate for terminating the current random walk path.
*
* @return A tuple of ID of next successor (-1 if not exist), the last traversed
* edge ID, as well as whether to terminate.
* edge ID, as well as whether to terminate.
*/
template <DGLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
......@@ -78,9 +74,9 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
const int64_t size = offsets[curr + 1] - offsets[curr];
if (size == 0) return std::make_tuple(-1, -1, true);
// Use a reference to the original array instead of copying
// This avoids updating the ref counts atomically from different threads
// and avoids cache ping-ponging in the tight loop
// Use a reference to the original array instead of copying. This avoids
// updating the ref counts atomically from different threads and avoids cache
// ping-ponging in the tight loop.
const FloatArray &prob_etype = prob[etype];
IdxType idx = 0;
if (IsNullArray(prob_etype)) {
......@@ -105,23 +101,22 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
/**
* @brief Select one successor of metapath-based random walk, given the path
* generated so far specifically for the uniform probability distribution.
* generated so far specifically for the uniform probability distribution.
*
* @param data The path generated so far, of type \c IdxType.
* @param curr The last node ID generated.
* @param len The number of nodes generated so far. Note that the seed node is
* always included as \c data[0], and the successors start from \c data[1].
* always included as \c data[0], and the successors start from \c data[1].
*
* @param edges_by_type Vector of results from \c GetAdj() by edge type.
* @param metapath_data Edge types of given metapath.
* @param prob Transition probability per edge type, for this special case this
* will be a NullArray
* @param terminate Predicate for terminating the current
* random walk path.
* will be a NullArray.
* @param terminate Predicate for terminating the current random walk path.
*
* @return A tuple of ID of next successor (-1 if not exist), the last traversed
* edge ID, as well as whether
* to terminate. \note This function is called only if all the probability
* arrays are null.
* to terminate. \note This function is called only if all the probability
* arrays are null.
*/
template <DGLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
......@@ -159,14 +154,14 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
* @brief Metapath-based random walk.
* @param hg The heterograph.
* @param seeds A 1D array of seed nodes, with the type the source type of the
* first edge type in the metapath.
* first edge type in the metapath.
* @param metapath A 1D array of edge types representing the metapath.
* @param prob A vector of 1D float arrays, indicating the transition
* probability of each edge by edge type. An empty float array assumes uniform
* transition.
* probability of each edge by edge type. An empty float array assumes
* uniform transition.
* @param terminate Predicate for terminating a random walk path.
* @return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs,
* and a 2D array of shape (len(seeds), len(metapath)) with edge IDs.
* and a 2D array of shape (len(seeds), len(metapath)) with edge IDs.
*/
template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment