Unverified Commit df089424 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Minor code style fix. (#4825)



* blabla

* more

* blabla

* blabla

* ablabla

* blabla
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 8ac27dad
...@@ -188,9 +188,11 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -188,9 +188,11 @@ class BaseHeteroGraph : public runtime::Object {
/** /**
* @brief Get all edge ids between the two given endpoints * @brief Get all edge ids between the two given endpoints
* @note The given src and dst vertices should belong to the source vertex * @note The given src and dst vertices should belong to the source vertex
* type and the dest vertex type of the given edge type, respectively. \param * type and the dest vertex type of the given edge type, respectively.
* etype The edge type \param src The source vertex. \param dst The * @param etype The edge type
* destination vertex. \return the edge id array. * @param src The source vertex.
* @param dst The destination vertex.
* @return the edge id array.
*/ */
virtual IdArray EdgeId( virtual IdArray EdgeId(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0; dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
...@@ -284,17 +286,18 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -284,17 +286,18 @@ class BaseHeteroGraph : public runtime::Object {
* @brief Get all the edges in the graph. * @brief Get all the edges in the graph.
* @note If order is "srcdst", the returned edges list is sorted by their src * @note If order is "srcdst", the returned edges list is sorted by their src
* and dst ids. If order is "eid", they are in their edge id order. Otherwise, * and dst ids. If order is "eid", they are in their edge id order. Otherwise,
* in the arbitrary order. \param etype The edge type \param order The order * in the arbitrary order.
* of the returned edge list. \return the id arrays of the two endpoints of * @param etype The edge type
* the edges. * @param order The order of the returned edge list.
* @return the id arrays of the two endpoints of the edges.
*/ */
virtual EdgeArray Edges( virtual EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const = 0; dgl_type_t etype, const std::string& order = "") const = 0;
/** /**
* @brief Get the in degree of the given vertex. * @brief Get the in degree of the given vertex.
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type of the given
* of the given edge type. * edge type.
* @param etype The edge type * @param etype The edge type
* @param vid The vertex id. * @param vid The vertex id.
* @return the in degree * @return the in degree
...@@ -303,8 +306,8 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -303,8 +306,8 @@ class BaseHeteroGraph : public runtime::Object {
/** /**
* @brief Get the in degrees of the given vertices. * @brief Get the in degrees of the given vertices.
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type of the given
* of the given edge type. * edge type.
* @param etype The edge type * @param etype The edge type
* @param vid The vertex id array. * @param vid The vertex id array.
* @return the in degree array * @return the in degree array
...@@ -313,8 +316,8 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -313,8 +316,8 @@ class BaseHeteroGraph : public runtime::Object {
/** /**
* @brief Get the out degree of the given vertex. * @brief Get the out degree of the given vertex.
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type of the given
* of the given edge type. * edge type.
* @param etype The edge type * @param etype The edge type
* @param vid The vertex id. * @param vid The vertex id.
* @return the out degree * @return the out degree
...@@ -323,8 +326,8 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -323,8 +326,8 @@ class BaseHeteroGraph : public runtime::Object {
/** /**
* @brief Get the out degrees of the given vertices. * @brief Get the out degrees of the given vertices.
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type of the given
* of the given edge type. * edge type.
* @param etype The edge type * @param etype The edge type
* @param vid The vertex id array. * @param vid The vertex id array.
* @return the out degree array * @return the out degree array
...@@ -333,8 +336,8 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -333,8 +336,8 @@ class BaseHeteroGraph : public runtime::Object {
/** /**
* @brief Return the successor vector * @brief Return the successor vector
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type of the given
* of the given edge type. * edge type.
* @param vid The vertex id. * @param vid The vertex id.
* @return the successor vector iterator pair. * @return the successor vector iterator pair.
*/ */
...@@ -342,8 +345,8 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -342,8 +345,8 @@ class BaseHeteroGraph : public runtime::Object {
/** /**
* @brief Return the out edge id vector * @brief Return the out edge id vector
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type of the given
* of the given edge type. * edge type.
* @param vid The vertex id. * @param vid The vertex id.
* @return the out edge id vector iterator pair. * @return the out edge id vector iterator pair.
*/ */
...@@ -351,8 +354,8 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -351,8 +354,8 @@ class BaseHeteroGraph : public runtime::Object {
/** /**
* @brief Return the predecessor vector * @brief Return the predecessor vector
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type of the given
* of the given edge type. * edge type.
* @param vid The vertex id. * @param vid The vertex id.
* @return the predecessor vector iterator pair. * @return the predecessor vector iterator pair.
*/ */
...@@ -360,8 +363,8 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -360,8 +363,8 @@ class BaseHeteroGraph : public runtime::Object {
/** /**
* @brief Return the in edge id vector * @brief Return the in edge id vector
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type of the given
* of the given edge type. * edge type.
* @param vid The vertex id. * @param vid The vertex id.
* @return the in edge id vector iterator pair. * @return the in edge id vector iterator pair.
*/ */
...@@ -391,7 +394,6 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -391,7 +394,6 @@ class BaseHeteroGraph : public runtime::Object {
/** /**
* @brief Determine which format to use with a preference. * @brief Determine which format to use with a preference.
* *
* Otherwise, it will return whatever DGL thinks is the most appropriate given * Otherwise, it will return whatever DGL thinks is the most appropriate given
* the arguments. * the arguments.
* *
...@@ -477,7 +479,8 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -477,7 +479,8 @@ class BaseHeteroGraph : public runtime::Object {
* *
* @param eids The edges in the subgraph. * @param eids The edges in the subgraph.
* @param preserve_nodes If true, the vertices will not be relabeled, so some * @param preserve_nodes If true, the vertices will not be relabeled, so some
* vertices may have no incident edges. \return the subgraph. * vertices may have no incident edges.
* @return the subgraph.
*/ */
virtual HeteroSubgraph EdgeSubgraph( virtual HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0; const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;
...@@ -713,7 +716,8 @@ HeteroGraphPtr CreateFromCSC( ...@@ -713,7 +716,8 @@ HeteroGraphPtr CreateFromCSC(
* @param graph Graph * @param graph Graph
* @param nodes Node IDs of each type * @param nodes Node IDs of each type
* @param relabel_nodes Whether to remove isolated nodes and relabel the rest * @param relabel_nodes Whether to remove isolated nodes and relabel the rest
* ones \return Subgraph containing only the in edges. The returned graph has * ones
* @return Subgraph containing only the in edges. The returned graph has
* the same schema as the original one. * the same schema as the original one.
*/ */
HeteroSubgraph InEdgeGraph( HeteroSubgraph InEdgeGraph(
...@@ -725,7 +729,8 @@ HeteroSubgraph InEdgeGraph( ...@@ -725,7 +729,8 @@ HeteroSubgraph InEdgeGraph(
* @param graph Graph * @param graph Graph
* @param nodes Node IDs of each type * @param nodes Node IDs of each type
* @param relabel_nodes Whether to remove isolated nodes and relabel the rest * @param relabel_nodes Whether to remove isolated nodes and relabel the rest
* ones \return Subgraph containing only the out edges. The returned graph has * ones
* @return Subgraph containing only the out edges. The returned graph has
* the same schema as the original one. * the same schema as the original one.
*/ */
HeteroSubgraph OutEdgeGraph( HeteroSubgraph OutEdgeGraph(
......
...@@ -246,7 +246,8 @@ DGL_DLL void DGLAPISetLastError(const char* msg); ...@@ -246,7 +246,8 @@ DGL_DLL void DGLAPISetLastError(const char* msg);
* DGLGetLastError can be called to retrieve the error * DGLGetLastError can be called to retrieve the error
* *
* this function is threadsafe and can be called by different thread * this function is threadsafe and can be called by different thread
* \return error info *
* @return error info
*/ */
DGL_DLL const char* DGLGetLastError(void); DGL_DLL const char* DGLGetLastError(void);
/** /**
......
...@@ -235,7 +235,9 @@ class List : public ObjectRef { ...@@ -235,7 +235,9 @@ class List : public ObjectRef {
} }
/** /**
* @brief Constructs a container with n elements. Each element is a copy of * @brief Constructs a container with n elements. Each element is a copy of
* val \param n The size of the container \param val The init value * val
* @param n The size of the container
* @param val The init value
*/ */
explicit List(size_t n, const T& val) { explicit List(size_t n, const T& val) {
auto tmp_obj = std::make_shared<ListObject>(); auto tmp_obj = std::make_shared<ListObject>();
......
...@@ -32,13 +32,13 @@ namespace impl { ...@@ -32,13 +32,13 @@ namespace impl {
// *ATTENTION*: This function will be invoked concurrently. Please make sure // *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe. // it is thread-safe.
// //
// \param rowid The row to pick from. // @param rowid The row to pick from.
// \param off Starting offset of this row. // @param off Starting offset of this row.
// \param len NNZ of the row. // @param len NNZ of the row.
// \param num_picks Number of picks on the row. // @param num_picks Number of picks on the row.
// \param col Pointer of the column indices. // @param col Pointer of the column indices.
// \param data Pointer of the data indices. // @param data Pointer of the data indices.
// \param out_idx Picked indices in [off, off + len). // @param out_idx Picked indices in [off, off + len).
template <typename IdxType> template <typename IdxType>
using PickFn = std::function<void( using PickFn = std::function<void(
IdxType rowid, IdxType off, IdxType len, IdxType num_picks, IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
...@@ -57,11 +57,11 @@ using PickFn = std::function<void( ...@@ -57,11 +57,11 @@ using PickFn = std::function<void(
// *ATTENTION*: This function will be invoked concurrently. Please make sure // *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe. // it is thread-safe.
// //
// \param rowid The row to pick from. // @param rowid The row to pick from.
// \param off Starting offset of this row. // @param off Starting offset of this row.
// \param len NNZ of the row. // @param len NNZ of the row.
// \param col Pointer of the column indices. // @param col Pointer of the column indices.
// \param data Pointer of the data indices. // @param data Pointer of the data indices.
template <typename IdxType> template <typename IdxType>
using NumPicksFn = std::function<IdxType( using NumPicksFn = std::function<IdxType(
IdxType rowid, IdxType off, IdxType len, const IdxType* col, IdxType rowid, IdxType off, IdxType len, const IdxType* col,
...@@ -80,14 +80,14 @@ using NumPicksFn = std::function<IdxType( ...@@ -80,14 +80,14 @@ using NumPicksFn = std::function<IdxType(
// *ATTENTION*: This function will be invoked concurrently. Please make sure // *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe. // it is thread-safe.
// //
// \param off Starting offset of this row. // @param off Starting offset of this row.
// \param et_offset Starting offset of this range. // @param et_offset Starting offset of this range.
// \param cur_et The edge type. // @param cur_et The edge type.
// \param et_len Length of the range. // @param et_len Length of the range.
// \param et_idx A map from local idx to column id. // @param et_idx A map from local idx to column id.
// \param et_eid Edge-type-specific id array. // @param et_eid Edge-type-specific id array.
// \param eid Pointer of the homogenized edge id array. // @param eid Pointer of the homogenized edge id array.
// \param out_idx Picked indices in [et_offset, et_offset + et_len). // @param out_idx Picked indices in [et_offset, et_offset + et_len).
template <typename IdxType> template <typename IdxType>
using EtypeRangePickFn = std::function<void( using EtypeRangePickFn = std::function<void(
IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len, IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
......
...@@ -226,7 +226,8 @@ void UpdateGradMinMax_hetero( ...@@ -226,7 +226,8 @@ void UpdateGradMinMax_hetero(
/** /**
* @brief CUDA implementation of backward phase of Segment Reduce with Min/Max * @brief CUDA implementation of backward phase of Segment Reduce with Min/Max
* reducer. * reducer.
* @note math equation: out[arg[i, k], k] = feat[i, k] \param feat The input * @note math equation: out[arg[i, k], k] = feat[i, k]
* @param feat The input
* tensor. * tensor.
* @param arg The ArgMin/Max information, used for indexing. * @param arg The ArgMin/Max information, used for indexing.
* @param out The output tensor. * @param out The output tensor.
......
...@@ -39,7 +39,7 @@ inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) { ...@@ -39,7 +39,7 @@ inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) {
} }
/** /**
* !\brief Find number of blocks is smaller than nblks and max_nblks * @brief Find number of blocks is smaller than nblks and max_nblks
* on the given axis ('x', 'y' or 'z'). * on the given axis ('x', 'y' or 'z').
*/ */
template <char axis> template <char axis>
......
/** /**
Copyright (c) 2021 Intel Corporation * Copyright (c) 2021 Intel Corporation
\file distgnn/partition/main_Libra.py *
\brief Libra - Vertex-cut based graph partitioner for distirbuted training * @file distgnn/partition/main_Libra.py
\author Vasimuddin Md <vasimuddin.md@intel.com>, * @brief Libra - Vertex-cut based graph partitioner for distirbuted training
Guixiang Ma <guixiang.ma@intel.com> * @author Vasimuddin Md <vasimuddin.md@intel.com>,
Sanchit Misra <sanchit.misra@intel.com>, * Guixiang Ma <guixiang.ma@intel.com>
Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>, * Sanchit Misra <sanchit.misra@intel.com>,
Sasikanth Avancha <sasikanth.avancha@intel.com> * Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,
Nesreen K. Ahmed <nesreen.k.ahmed@intel.com> * Sasikanth Avancha <sasikanth.avancha@intel.com>
*/ * Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>
*/
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
...@@ -347,7 +348,8 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibraVertexCut") ...@@ -347,7 +348,8 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibraVertexCut")
* @param[in, out] offset start of the range of local node IDs for this * @param[in, out] offset start of the range of local node IDs for this
* partition * partition
* @param[in] nc number of partitions/communities * @param[in] nc number of partitions/communities
* @param[in] c current partition number \param[in] fsize size of pre-allocated * @param[in] c current partition number
* @param[in] fsize size of pre-allocated
* memory tensor * memory tensor
* @param[in] prefix input Libra partition file location * @param[in] prefix input Libra partition file location
*/ */
...@@ -516,7 +518,8 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglSetLR") ...@@ -516,7 +518,8 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLLibra2dglSetLR")
* @param[out] inner_nodes marks whether a node is split or not. * @param[out] inner_nodes marks whether a node is split or not.
* @param[in] ldt_key per partition dict for tracking global to local node IDs * @param[in] ldt_key per partition dict for tracking global to local node IDs
* @param[out] gdt_key global dict for storing number of local nodes (or split * @param[out] gdt_key global dict for storing number of local nodes (or split
* nodes) for a given global node ID \param[out] gdt_value global * nodes) for a given global node ID
* @param[out] gdt_value global
* dict, stores local node IDs (due to split) across partitions for * dict, stores local node IDs (due to split) across partitions for
* a given global node ID. * a given global node ID.
* @param[in] node_map keeps track of range of local node IDs (consecutive) * @param[in] node_map keeps track of range of local node IDs (consecutive)
......
...@@ -255,7 +255,8 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -255,7 +255,8 @@ class HeteroGraph : public BaseHeteroGraph {
*/ */
void RecordStream(DGLStreamHandle stream) override; void RecordStream(DGLStreamHandle stream) override;
/** @brief Copy the data to shared memory. /**
* @brief Copy the data to shared memory.
* *
* Also save names of node types and edge types of the HeteroGraph object to * Also save names of node types and edge types of the HeteroGraph object to
* shared memory * shared memory
...@@ -266,8 +267,10 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -266,8 +267,10 @@ class HeteroGraph : public BaseHeteroGraph {
const std::vector<std::string>& etypes, const std::vector<std::string>& etypes,
const std::set<std::string>& fmts); const std::set<std::string>& fmts);
/** @brief Create a heterograph from /**
* \return the HeteroGraphPtr, names of node types, names of edge types * @brief Create a heterograph from
*
* @return the HeteroGraphPtr, names of node types, names of edge types
*/ */
static std::tuple< static std::tuple<
HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>> HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
...@@ -296,11 +299,14 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -296,11 +299,14 @@ class HeteroGraph : public BaseHeteroGraph {
/** @brief The shared memory object for meta info*/ /** @brief The shared memory object for meta info*/
std::shared_ptr<runtime::SharedMemory> shared_mem_; std::shared_ptr<runtime::SharedMemory> shared_mem_;
/** @brief The name of the shared memory. Return empty string if it is not in /**
* shared memory. */ * @brief The name of the shared memory. Return empty string if it is not in
* shared memory.
*/
std::string SharedMemName() const; std::string SharedMemName() const;
/** @brief template class for Flatten operation /**
* @brief template class for Flatten operation
* *
* @tparam IdType Graph's index data type, can be int32_t or int64_t * @tparam IdType Graph's index data type, can be int32_t or int64_t
* @param etypes vector of etypes to be falttened * @param etypes vector of etypes to be falttened
......
...@@ -115,7 +115,8 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( ...@@ -115,7 +115,8 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
* @param edges_by_type Vector of results from \c GetAdj() by edge type. * @param edges_by_type Vector of results from \c GetAdj() by edge type.
* @param metapath_data Edge types of given metapath. * @param metapath_data Edge types of given metapath.
* @param prob Transition probability per edge type, for this special case this * @param prob Transition probability per edge type, for this special case this
* will be a NullArray \param terminate Predicate for terminating the current * will be a NullArray
* @param terminate Predicate for terminating the current
* random walk path. * random walk path.
* *
* @return A pair of ID of next successor (-1 if not exist), as well as whether * @return A pair of ID of next successor (-1 if not exist), as well as whether
...@@ -158,13 +159,14 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform( ...@@ -158,13 +159,14 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
* @brief Metapath-based random walk. * @brief Metapath-based random walk.
* @param hg The heterograph. * @param hg The heterograph.
* @param seeds A 1D array of seed nodes, with the type the source type of the * @param seeds A 1D array of seed nodes, with the type the source type of the
* first edge type in the metapath. \param metapath A 1D array of edge types * first edge type in the metapath.
* representing the metapath. \param prob A vector of 1D float arrays, * @param metapath A 1D array of edge types representing the metapath.
* indicating the transition probability of each edge by edge type. An empty * @param prob A vector of 1D float arrays, indicating the transition
* float array assumes uniform transition. \param terminate Predicate for * probability of each edge by edge type. An empty float array assumes uniform
* terminating a random walk path. \return A 2D array of shape (len(seeds), * transition.
* len(metapath) + 1) with node IDs, and A 2D array of shape (len(seeds), * @param terminate Predicate for terminating a random walk path.
* len(metapath)) with edge IDs. * @return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs,
* and A 2D array of shape (len(seeds), len(metapath)) with edge IDs.
*/ */
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> MetapathBasedRandomWalk( std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
......
...@@ -31,11 +31,14 @@ namespace impl { ...@@ -31,11 +31,14 @@ namespace impl {
* @param seeds A 1D array of seed nodes, with the type the source type of the * @param seeds A 1D array of seed nodes, with the type the source type of the
* first edge type in the metapath. * first edge type in the metapath.
* @param p Float, indicating likelihood of immediately revisiting a node in the * @param p Float, indicating likelihood of immediately revisiting a node in the
* walk. \param q Float, control parameter to interpolate between breadth-first * walk.
* strategy and depth-first strategy. \param walk_length Int, length of walk. * @param q Float, control parameter to interpolate between breadth-first
* strategy and depth-first strategy.
* @param walk_length Int, length of walk.
* @param prob A vector of 1D float arrays, indicating the transition * @param prob A vector of 1D float arrays, indicating the transition
* probability of each edge by edge type. An empty float array assumes * probability of each edge by edge type. An empty float array assumes
* uniform transition. \return A 2D array of shape (len(seeds), len(walk_length) * uniform transition.
* @return A 2D array of shape (len(seeds), len(walk_length)
* + 1) with node IDs. The paths that terminated early are padded with -1. * + 1) with node IDs. The paths that terminated early are padded with -1.
*/ */
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
......
...@@ -56,11 +56,12 @@ bool has_edge_between(const CSRMatrix &csr, dgl_id_t u, dgl_id_t v) { ...@@ -56,11 +56,12 @@ bool has_edge_between(const CSRMatrix &csr, dgl_id_t u, dgl_id_t v) {
* @param q Float, control parameter to interpolate between breadth-first * @param q Float, control parameter to interpolate between breadth-first
* strategy and depth-first strategy. * strategy and depth-first strategy.
* @param len The number of nodes generated so far. Note that the seed node is * @param len The number of nodes generated so far. Note that the seed node is
* always included as \c data[0], and the successors start from \c * always included as \c data[0], and the successors start from \c data[1].
* data[1]. \param csr The CSR matrix \param prob Transition probability \param * @param csr The CSR matrix
* terminate Predicate for terminating the current random walk path. \return A * @param prob Transition probability
* tuple of ID of next successor (-1 if not exist), the edge ID traversed, as * @param terminate Predicate for terminating the current random walk path.
* well as whether to terminate. * @return A tuple of ID of next successor (-1 if not exist), the edge ID
* traversed, as well as whether to terminate.
*/ */
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
......
...@@ -31,12 +31,13 @@ namespace { ...@@ -31,12 +31,13 @@ namespace {
/** /**
* @brief Generic Random Walk. * @brief Generic Random Walk.
* @param seeds A 1D array of seed nodes, with the type the source type of the * @param seeds A 1D array of seed nodes, with the type the source type of the
* first edge type in the metapath. \param max_num_steps The maximum number of * first edge type in the metapath.
* steps of a random walk path. \param step The random walk step function with * @param max_num_steps The maximum number of steps of a random walk path.
* type \c StepFunc. \param max_nodes Throws an error if one of the values in \c * @param step The random walk step function with type \c StepFunc.
* seeds exceeds this argument. \return A 2D array of shape (len(seeds), * @param max_nodes Throws an error if one of the values in \c seeds exceeds
* max_num_steps + 1) with node IDs. \note The graph itself should be bounded in * this argument.
* the closure of \c step. * @return A 2D array of shape (len(seeds), max_num_steps + 1) with node IDs.
* @note The graph itself should be bounded in the closure of \c step.
*/ */
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> GenericRandomWalk( std::pair<IdArray, IdArray> GenericRandomWalk(
......
...@@ -47,10 +47,13 @@ TypeArray GetNodeTypesFromMetapath( ...@@ -47,10 +47,13 @@ TypeArray GetNodeTypesFromMetapath(
* @brief Metapath-based random walk. * @brief Metapath-based random walk.
* @param hg The heterograph. * @param hg The heterograph.
* @param seeds A 1D array of seed nodes, with the type the source type of the * @param seeds A 1D array of seed nodes, with the type the source type of the
* first edge type in the metapath. \param metapath A 1D array of edge types * first edge type in the metapath.
* representing the metapath. \param prob A vector of 1D float arrays, * @param metapath A 1D array of edge types
* representing the metapath.
* @param prob A vector of 1D float arrays,
* indicating the transition probability of each edge by edge type. An empty * indicating the transition probability of each edge by edge type. An empty
* float array assumes uniform transition. \return A 2D array of shape * float array assumes uniform transition.
* @return A 2D array of shape
* (len(seeds), len(metapath) + 1) with node IDs. The paths that terminated * (len(seeds), len(metapath) + 1) with node IDs. The paths that terminated
* early are padded with -1. A 2D array of shape (len(seeds), len(metapath)) * early are padded with -1. A 2D array of shape (len(seeds), len(metapath))
* with edge IDs. The paths that terminated early are padded with -1. \note * with edge IDs. The paths that terminated early are padded with -1. \note
...@@ -66,11 +69,15 @@ std::pair<IdArray, IdArray> RandomWalk( ...@@ -66,11 +69,15 @@ std::pair<IdArray, IdArray> RandomWalk(
* @brief Metapath-based random walk with restart probability. * @brief Metapath-based random walk with restart probability.
* @param hg The heterograph. * @param hg The heterograph.
* @param seeds A 1D array of seed nodes, with the type the source type of the * @param seeds A 1D array of seed nodes, with the type the source type of the
* first edge type in the metapath. \param metapath A 1D array of edge types * first edge type in the metapath.
* representing the metapath. \param prob A vector of 1D float arrays, * @param metapath A 1D array of edge types
* representing the metapath.
* @param prob A vector of 1D float arrays,
* indicating the transition probability of each edge by edge type. An empty * indicating the transition probability of each edge by edge type. An empty
* float array assumes uniform transition. \param restart_prob Restart * float array assumes uniform transition.
* probability \return A 2D array of shape (len(seeds), len(metapath) + 1) with * @param restart_prob Restart
* probability
* @return A 2D array of shape (len(seeds), len(metapath) + 1) with
* node IDs. The paths that terminated early are padded with -1. A 2D array of * node IDs. The paths that terminated early are padded with -1. A 2D array of
* shape (len(seeds), len(metapath)) with edge IDs. The paths that terminated * shape (len(seeds), len(metapath)) with edge IDs. The paths that terminated
* early are padded with -1. \note This function should be called together with * early are padded with -1. \note This function should be called together with
...@@ -87,12 +94,16 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart( ...@@ -87,12 +94,16 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
* for PinSAGE-like models. * for PinSAGE-like models.
* @param hg The heterograph. * @param hg The heterograph.
* @param seeds A 1D array of seed nodes, with the type the source type of the * @param seeds A 1D array of seed nodes, with the type the source type of the
* first edge type in the metapath. \param metapath A 1D array of edge types * first edge type in the metapath.
* representing the metapath. \param prob A vector of 1D float arrays, * @param metapath A 1D array of edge types
* representing the metapath.
* @param prob A vector of 1D float arrays,
* indicating the transition probability of each edge by edge type. An empty * indicating the transition probability of each edge by edge type. An empty
* float array assumes uniform transition. \param restart_prob Restart * float array assumes uniform transition.
* @param restart_prob Restart
* probability array which has the same number of elements as \c metapath, * probability array which has the same number of elements as \c metapath,
* indicating the probability to terminate after transition. \return A 2D array * indicating the probability to terminate after transition.
* @return A 2D array
* of shape (len(seeds), len(metapath) + 1) with node IDs. The paths that * of shape (len(seeds), len(metapath) + 1) with node IDs. The paths that
* terminated early are padded with -1. A 2D array of shape (len(seeds), * terminated early are padded with -1. A 2D array of shape (len(seeds),
* len(metapath)) with edge IDs. The paths that terminated early are padded * len(metapath)) with edge IDs. The paths that terminated early are padded
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment