Unverified Commit 4cf5f682 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Polish metapath_randomwalk.h (#5471)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 8a830272
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* @file graph/sampler/generic_randomwalk_cpu.h * @file graph/sampler/generic_randomwalk_cpu.h
* @brief DGL sampler - templated implementation definition of random walks on * @brief DGL sampler - templated implementation definition of random walks on
* CPU * CPU.
*/ */
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_ #ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_
...@@ -30,10 +30,6 @@ namespace impl { ...@@ -30,10 +30,6 @@ namespace impl {
namespace { namespace {
// bool WhetherToTerminate(
// IdxType *node_ids_generated_so_far,
// dgl_id_t last_node_id_generated,
// int64_t number_of_nodes_generated_so_far)
template <typename IdxType> template <typename IdxType>
using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>; using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>;
...@@ -78,9 +74,9 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( ...@@ -78,9 +74,9 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
const int64_t size = offsets[curr + 1] - offsets[curr]; const int64_t size = offsets[curr + 1] - offsets[curr];
if (size == 0) return std::make_tuple(-1, -1, true); if (size == 0) return std::make_tuple(-1, -1, true);
// Use a reference to the original array instead of copying // Use a reference to the original array instead of copying. This avoids
// This avoids updating the ref counts atomically from different threads // updating the ref counts atomically from different threads and avoids cache
// and avoids cache ping-ponging in the tight loop // ping-ponging in the tight loop.
const FloatArray &prob_etype = prob[etype]; const FloatArray &prob_etype = prob[etype];
IdxType idx = 0; IdxType idx = 0;
if (IsNullArray(prob_etype)) { if (IsNullArray(prob_etype)) {
...@@ -115,9 +111,8 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( ...@@ -115,9 +111,8 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
* @param edges_by_type Vector of results from \c GetAdj() by edge type. * @param edges_by_type Vector of results from \c GetAdj() by edge type.
* @param metapath_data Edge types of given metapath. * @param metapath_data Edge types of given metapath.
* @param prob Transition probability per edge type, for this special case this * @param prob Transition probability per edge type, for this special case this
* will be a NullArray * will be a NullArray.
* @param terminate Predicate for terminating the current * @param terminate Predicate for terminating the current random walk path.
* random walk path.
* *
* @return A pair of ID of next successor (-1 if not exist), as well as whether * @return A pair of ID of next successor (-1 if not exist), as well as whether
* to terminate. \note This function is called only if all the probability * to terminate. \note This function is called only if all the probability
...@@ -162,8 +157,8 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform( ...@@ -162,8 +157,8 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
* first edge type in the metapath. * first edge type in the metapath.
* @param metapath A 1D array of edge types representing the metapath. * @param metapath A 1D array of edge types representing the metapath.
* @param prob A vector of 1D float arrays, indicating the transition * @param prob A vector of 1D float arrays, indicating the transition
* probability of each edge by edge type. An empty float array assumes uniform * probability of each edge by edge type. An empty float array assumes
* transition. * uniform transition.
* @param terminate Predicate for terminating a random walk path. * @param terminate Predicate for terminating a random walk path.
* @return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs, * @return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs,
* and A 2D array of shape (len(seeds), len(metapath)) with edge IDs. * and A 2D array of shape (len(seeds), len(metapath)) with edge IDs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment