Unverified Commit 619d735d authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Replace \xxx with @XXX in structured comment. (#4822)



* param

* brief

* note

* return

* tparam

* brief2

* file

* return2

* return

* blabla

* all
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 96297fb8
/*! /*!
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* \file dgl/aten/array_ops.h * @file dgl/aten/array_ops.h
* \brief Common array operations required by DGL. * @brief Common array operations required by DGL.
* *
* Note that this is not meant for a full support of array library such as ATen. * Note that this is not meant for a full support of array library such as ATen.
* Only a limited set of operators required by DGL are implemented. * Only a limited set of operators required by DGL are implemented.
...@@ -23,36 +23,36 @@ namespace aten { ...@@ -23,36 +23,36 @@ namespace aten {
// ID array // ID array
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
/*! \return A special array to represent null. */ /*! @return A special array to represent null. */
inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1}, inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
const DGLContext& ctx = DGLContext{kDGLCPU, 0}) { const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
return NDArray::Empty({0}, dtype, ctx); return NDArray::Empty({0}, dtype, ctx);
} }
/*! /*!
* \return Whether the input array is a null array. * @return Whether the input array is a null array.
*/ */
inline bool IsNullArray(NDArray array) { inline bool IsNullArray(NDArray array) {
return array->shape[0] == 0; return array->shape[0] == 0;
} }
/*! /*!
* \brief Create a new id array with given length * @brief Create a new id array with given length
* \param length The array length * @param length The array length
* \param ctx The array context * @param ctx The array context
* \param nbits The number of integer bits * @param nbits The number of integer bits
* \return id array * @return id array
*/ */
IdArray NewIdArray(int64_t length, IdArray NewIdArray(int64_t length,
DGLContext ctx = DGLContext{kDGLCPU, 0}, DGLContext ctx = DGLContext{kDGLCPU, 0},
uint8_t nbits = 64); uint8_t nbits = 64);
/*! /*!
* \brief Create a new id array using the given vector data * @brief Create a new id array using the given vector data
* \param vec The vector data * @param vec The vector data
* \param nbits The integer bits of the returned array * @param nbits The integer bits of the returned array
* \param ctx The array context * @param ctx The array context
* \return the id array * @return the id array
*/ */
template <typename T> template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec, IdArray VecToIdArray(const std::vector<T>& vec,
...@@ -60,42 +60,42 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -60,42 +60,42 @@ IdArray VecToIdArray(const std::vector<T>& vec,
DGLContext ctx = DGLContext{kDGLCPU, 0}); DGLContext ctx = DGLContext{kDGLCPU, 0});
/*! /*!
* \brief Return an array representing a 1D range. * @brief Return an array representing a 1D range.
* \param low Lower bound (inclusive). * @param low Lower bound (inclusive).
* \param high Higher bound (exclusive). * @param high Higher bound (exclusive).
* \param nbits result array's bits (32 or 64) * @param nbits result array's bits (32 or 64)
* \param ctx Device context * @param ctx Device context
* \return range array * @return range array
*/ */
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx); IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx);
/*! /*!
* \brief Return an array full of the given value * @brief Return an array full of the given value
* \param val The value to fill. * @param val The value to fill.
* \param length Number of elements. * @param length Number of elements.
* \param nbits result array's bits (32 or 64) * @param nbits result array's bits (32 or 64)
* \param ctx Device context * @param ctx Device context
* \return the result array * @return the result array
*/ */
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx); IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx);
/*! /*!
* \brief Return an array full of the given value with the given type. * @brief Return an array full of the given value with the given type.
* \param val The value to fill. * @param val The value to fill.
* \param length Number of elements. * @param length Number of elements.
* \param ctx Device context * @param ctx Device context
* \return the result array * @return the result array
*/ */
template <typename DType> template <typename DType>
NDArray Full(DType val, int64_t length, DGLContext ctx); NDArray Full(DType val, int64_t length, DGLContext ctx);
/*! \brief Create a deep copy of the given array */ /*! @brief Create a deep copy of the given array */
IdArray Clone(IdArray arr); IdArray Clone(IdArray arr);
/*! \brief Convert the idarray to the given bit width */ /*! @brief Convert the idarray to the given bit width */
IdArray AsNumBits(IdArray arr, uint8_t bits); IdArray AsNumBits(IdArray arr, uint8_t bits);
/*! \brief Arithmetic functions */ /*! @brief Arithmetic functions */
IdArray Add(IdArray lhs, IdArray rhs); IdArray Add(IdArray lhs, IdArray rhs);
IdArray Sub(IdArray lhs, IdArray rhs); IdArray Sub(IdArray lhs, IdArray rhs);
IdArray Mul(IdArray lhs, IdArray rhs); IdArray Mul(IdArray lhs, IdArray rhs);
...@@ -138,31 +138,31 @@ IdArray LE(int64_t lhs, IdArray rhs); ...@@ -138,31 +138,31 @@ IdArray LE(int64_t lhs, IdArray rhs);
IdArray EQ(int64_t lhs, IdArray rhs); IdArray EQ(int64_t lhs, IdArray rhs);
IdArray NE(int64_t lhs, IdArray rhs); IdArray NE(int64_t lhs, IdArray rhs);
/*! \brief Stack two arrays (of len L) into a 2*L length array */ /*! @brief Stack two arrays (of len L) into a 2*L length array */
IdArray HStack(IdArray arr1, IdArray arr2); IdArray HStack(IdArray arr1, IdArray arr2);
/*! \brief Return the indices of the elements that are non-zero. */ /*! @brief Return the indices of the elements that are non-zero. */
IdArray NonZero(BoolArray bool_arr); IdArray NonZero(BoolArray bool_arr);
/*! /*!
* \brief Return the data under the index. In numpy notation, A[I] * @brief Return the data under the index. In numpy notation, A[I]
* \tparam ValueType The type of return value. * @tparam ValueType The type of return value.
*/ */
template<typename ValueType> template<typename ValueType>
ValueType IndexSelect(NDArray array, int64_t index); ValueType IndexSelect(NDArray array, int64_t index);
/*! /*!
* \brief Return the data under the index. In numpy notation, A[I] * @brief Return the data under the index. In numpy notation, A[I]
*/ */
NDArray IndexSelect(NDArray array, IdArray index); NDArray IndexSelect(NDArray array, IdArray index);
/*! /*!
* \brief Return the data from `start` (inclusive) to `end` (exclusive). * @brief Return the data from `start` (inclusive) to `end` (exclusive).
*/ */
NDArray IndexSelect(NDArray array, int64_t start, int64_t end); NDArray IndexSelect(NDArray array, int64_t start, int64_t end);
/*! /*!
* \brief Permute the elements of an array according to given indices. * @brief Permute the elements of an array according to given indices.
* *
* Only support 1D arrays. * Only support 1D arrays.
* *
...@@ -176,7 +176,7 @@ NDArray IndexSelect(NDArray array, int64_t start, int64_t end); ...@@ -176,7 +176,7 @@ NDArray IndexSelect(NDArray array, int64_t start, int64_t end);
NDArray Scatter(NDArray array, IdArray indices); NDArray Scatter(NDArray array, IdArray indices);
/*! /*!
* \brief Scatter data into the output array. * @brief Scatter data into the output array.
* *
* Equivalent to: * Equivalent to:
* *
...@@ -187,15 +187,15 @@ NDArray Scatter(NDArray array, IdArray indices); ...@@ -187,15 +187,15 @@ NDArray Scatter(NDArray array, IdArray indices);
void Scatter_(IdArray index, NDArray value, NDArray out); void Scatter_(IdArray index, NDArray value, NDArray out);
/*! /*!
* \brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats) * @brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats)
* \param array A 1D vector * @param array A 1D vector
* \param repeats A 1D integer vector for number of times to repeat for each element in * @param repeats A 1D integer vector for number of times to repeat for each element in
* \c array. Must have the same shape as \c array. * \c array. Must have the same shape as \c array.
*/ */
NDArray Repeat(NDArray array, IdArray repeats); NDArray Repeat(NDArray array, IdArray repeats);
/*! /*!
* \brief Relabel the given ids to consecutive ids. * @brief Relabel the given ids to consecutive ids.
* *
* Relabeling is done inplace. The mapping is created from the union * Relabeling is done inplace. The mapping is created from the union
* of the give arrays. * of the give arrays.
...@@ -206,21 +206,21 @@ NDArray Repeat(NDArray array, IdArray repeats); ...@@ -206,21 +206,21 @@ NDArray Repeat(NDArray array, IdArray repeats);
* mapping is [2, 3, 10, 4, 0, 5], meaning the new ID 0 maps to the old ID * mapping is [2, 3, 10, 4, 0, 5], meaning the new ID 0 maps to the old ID
* 2, 1 maps to 3, so on and so forth. * 2, 1 maps to 3, so on and so forth.
* *
* \param arrays The id arrays to relabel. * @param arrays The id arrays to relabel.
* \return mapping array M from new id to old id. * @return mapping array M from new id to old id.
*/ */
IdArray Relabel_(const std::vector<IdArray>& arrays); IdArray Relabel_(const std::vector<IdArray>& arrays);
/*! /*!
* \brief concatenate the given id arrays to one array * @brief concatenate the given id arrays to one array
* *
* Example: * Example:
* *
* Given two IdArrays [2, 3, 10, 0, 2] and [4, 10, 5] * Given two IdArrays [2, 3, 10, 0, 2] and [4, 10, 5]
* Return [2, 3, 10, 0, 2, 4, 10, 5] * Return [2, 3, 10, 0, 2, 4, 10, 5]
* *
* \param arrays The id arrays to concatenate. * @param arrays The id arrays to concatenate.
* \return concatenated array. * @return concatenated array.
*/ */
NDArray Concat(const std::vector<IdArray>& arrays); NDArray Concat(const std::vector<IdArray>& arrays);
...@@ -230,7 +230,7 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) { ...@@ -230,7 +230,7 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
} }
/*! /*!
* \brief Packs a tensor containing padded sequences of variable length. * @brief Packs a tensor containing padded sequences of variable length.
* *
* Similar to \c pack_padded_sequence in PyTorch, except that * Similar to \c pack_padded_sequence in PyTorch, except that
* *
...@@ -240,11 +240,11 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) { ...@@ -240,11 +240,11 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
* 3. Along with the tensor containing the packed sequence, it returns both the * 3. Along with the tensor containing the packed sequence, it returns both the
* length, as well as the offsets to the packed tensor, of each sequence. * length, as well as the offsets to the packed tensor, of each sequence.
* *
* \param array The tensor containing sequences padded to the same length * @param array The tensor containing sequences padded to the same length
* \param pad_value The padding value * @param pad_value The padding value
* \return A triplet of packed tensor, the length tensor, and the offset tensor * @return A triplet of packed tensor, the length tensor, and the offset tensor
* *
* \note Example: consider the following array with padding value -1: * @note Example: consider the following array with padding value -1:
* *
* <code> * <code>
* [[1, 2, -1, -1], * [[1, 2, -1, -1],
...@@ -262,7 +262,7 @@ template<typename ValueType> ...@@ -262,7 +262,7 @@ template<typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value); std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
/*! /*!
* \brief Batch-slice a 1D or 2D array, and then pack the list of sliced arrays * @brief Batch-slice a 1D or 2D array, and then pack the list of sliced arrays
* by concatenation. * by concatenation.
* *
* If a 2D array is given, then the function is equivalent to: * If a 2D array is given, then the function is equivalent to:
...@@ -285,14 +285,14 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value); ...@@ -285,14 +285,14 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
* return packed, offsets * return packed, offsets
* </code> * </code>
* *
* \param array A 1D or 2D tensor for slicing * @param array A 1D or 2D tensor for slicing
* \param lengths A 1D tensor indicating the number of elements to slice * @param lengths A 1D tensor indicating the number of elements to slice
* \return The tensor with packed slices along with the offsets. * @return The tensor with packed slices along with the offsets.
*/ */
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths); std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
/*! /*!
* \brief Return the cumulative summation (or inclusive sum) of the input array. * @brief Return the cumulative summation (or inclusive sum) of the input array.
* *
* The first element out[0] is equal to the first element of the input array * The first element out[0] is equal to the first element of the input array
* array[0]. The rest elements are defined recursively, out[i] = out[i-1] + array[i]. * array[0]. The rest elements are defined recursively, out[i] = out[i-1] + array[i].
...@@ -302,40 +302,40 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths); ...@@ -302,40 +302,40 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
* length is the input array length plus one. This is useful for creating * length is the input array length plus one. This is useful for creating
* an indptr array over a count array. * an indptr array over a count array.
* *
* \param array The 1D input array. * @param array The 1D input array.
* \return Array after cumsum. * @return Array after cumsum.
*/ */
IdArray CumSum(IdArray array, bool prepend_zero = false); IdArray CumSum(IdArray array, bool prepend_zero = false);
/*! /*!
* \brief Return the nonzero index. * @brief Return the nonzero index.
* *
* Only support 1D array. The result index array is in int64. * Only support 1D array. The result index array is in int64.
* *
* \param array The input array. * @param array The input array.
* \return A 1D index array storing the positions of the non zero values. * @return A 1D index array storing the positions of the non zero values.
*/ */
IdArray NonZero(NDArray array); IdArray NonZero(NDArray array);
/*! /*!
* \brief Sort the ID vector in ascending order. * @brief Sort the ID vector in ascending order.
* *
* It performs both sort and arg_sort (returning the sorted index). The sorted index * It performs both sort and arg_sort (returning the sorted index). The sorted index
* is always in int64. * is always in int64.
* *
* \param array Input array. * @param array Input array.
* \param num_bits The number of bits used in key comparison. For example, if the data type * @param num_bits The number of bits used in key comparison. For example, if the data type
* of the input array is int32_t and `num_bits = 8`, it only uses bits in index * of the input array is int32_t and `num_bits = 8`, it only uses bits in index
* range [0, 8) for sorting. Setting it to a small value could * range [0, 8) for sorting. Setting it to a small value could
* speed up the sorting if the underlying sorting algorithm is radix sort (e.g., on GPU). * speed up the sorting if the underlying sorting algorithm is radix sort (e.g., on GPU).
* Setting it to zero (default value) means using all the bits for comparison. * Setting it to zero (default value) means using all the bits for comparison.
* On CPU, it currently has no effect. * On CPU, it currently has no effect.
* \return A pair of arrays: sorted values and sorted index to the original position. * @return A pair of arrays: sorted values and sorted index to the original position.
*/ */
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0); std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0);
/*! /*!
* \brief Return a string that prints out some debug information. * @brief Return a string that prints out some debug information.
*/ */
std::string ToDebugString(NDArray array); std::string ToDebugString(NDArray array);
...@@ -356,7 +356,7 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -356,7 +356,7 @@ IdArray VecToIdArray(const std::vector<T>& vec,
} }
/*! /*!
* \brief Get the context of the first array, and check if the non-null arrays' * @brief Get the context of the first array, and check if the non-null arrays'
* contexts are the same. * contexts are the same.
*/ */
inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) { inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) {
......
/*! /*!
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* \file dgl/aten/coo.h * @file dgl/aten/coo.h
* \brief Common COO operations required by DGL. * @brief Common COO operations required by DGL.
*/ */
#ifndef DGL_ATEN_COO_H_ #ifndef DGL_ATEN_COO_H_
#define DGL_ATEN_COO_H_ #define DGL_ATEN_COO_H_
...@@ -24,7 +24,7 @@ namespace aten { ...@@ -24,7 +24,7 @@ namespace aten {
struct CSRMatrix; struct CSRMatrix;
/*! /*!
* \brief Plain COO structure * @brief Plain COO structure
* *
* The data array stores integer ids for reading edge features. * The data array stores integer ids for reading edge features.
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries * Note that we do allow duplicate non-zero entries -- multiple non-zero entries
...@@ -37,21 +37,21 @@ constexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127; ...@@ -37,21 +37,21 @@ constexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127;
// TODO(BarclayII): Graph queries on COO formats should support the case where // TODO(BarclayII): Graph queries on COO formats should support the case where
// data ordered by rows/columns instead of EID. // data ordered by rows/columns instead of EID.
struct COOMatrix { struct COOMatrix {
/*! \brief the dense shape of the matrix */ /*! @brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0; int64_t num_rows = 0, num_cols = 0;
/*! \brief COO index arrays */ /*! @brief COO index arrays */
IdArray row, col; IdArray row, col;
/*! \brief data index array. When is null, assume it is from 0 to NNZ - 1. */ /*! @brief data index array. When is null, assume it is from 0 to NNZ - 1. */
IdArray data; IdArray data;
/*! \brief whether the row indices are sorted */ /*! @brief whether the row indices are sorted */
bool row_sorted = false; bool row_sorted = false;
/*! \brief whether the column indices per row are sorted */ /*! @brief whether the column indices per row are sorted */
bool col_sorted = false; bool col_sorted = false;
/*! \brief whether the matrix is in pinned memory */ /*! @brief whether the matrix is in pinned memory */
bool is_pinned = false; bool is_pinned = false;
/*! \brief default constructor */ /*! @brief default constructor */
COOMatrix() = default; COOMatrix() = default;
/*! \brief constructor */ /*! @brief constructor */
COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr, COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false, IdArray darr = NullArray(), bool rsorted = false,
bool csorted = false) bool csorted = false)
...@@ -65,7 +65,7 @@ struct COOMatrix { ...@@ -65,7 +65,7 @@ struct COOMatrix {
CheckValidity(); CheckValidity();
} }
/*! \brief constructor from SparseMatrix object */ /*! @brief constructor from SparseMatrix object */
explicit COOMatrix(const SparseMatrix& spmat) explicit COOMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows), : num_rows(spmat.num_rows),
num_cols(spmat.num_cols), num_cols(spmat.num_cols),
...@@ -121,7 +121,7 @@ struct COOMatrix { ...@@ -121,7 +121,7 @@ struct COOMatrix {
CHECK_NO_OVERFLOW(row->dtype, num_cols); CHECK_NO_OVERFLOW(row->dtype, num_cols);
} }
/*! \brief Return a copy of this matrix on the give device context. */ /*! @brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DGLContext &ctx) const { inline COOMatrix CopyTo(const DGLContext &ctx) const {
if (ctx == row->ctx) if (ctx == row->ctx)
return *this; return *this;
...@@ -131,8 +131,8 @@ struct COOMatrix { ...@@ -131,8 +131,8 @@ struct COOMatrix {
} }
/*! /*!
* \brief Pin the row, col and data (if not Null) of the matrix. * @brief Pin the row, col and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* kDGLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDGLCUDA: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
...@@ -150,8 +150,8 @@ struct COOMatrix { ...@@ -150,8 +150,8 @@ struct COOMatrix {
} }
/*! /*!
* \brief Unpin the row, col and data (if not Null) of the matrix. * @brief Unpin the row, col and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
...@@ -168,8 +168,8 @@ struct COOMatrix { ...@@ -168,8 +168,8 @@ struct COOMatrix {
} }
/*! /*!
* \brief Record stream for the row, col and data (if not Null) of the matrix. * @brief Record stream for the row, col and data (if not Null) of the matrix.
* \param stream The stream that is using the graph * @param stream The stream that is using the graph
*/ */
inline void RecordStream(DGLStreamHandle stream) const { inline void RecordStream(DGLStreamHandle stream) const {
row.RecordStream(stream); row.RecordStream(stream);
...@@ -182,29 +182,29 @@ struct COOMatrix { ...@@ -182,29 +182,29 @@ struct COOMatrix {
///////////////////////// COO routines ////////////////////////// ///////////////////////// COO routines //////////////////////////
/*! \brief Return true if the value (row, col) is non-zero */ /*! @brief Return true if the value (row, col) is non-zero */
bool COOIsNonZero(COOMatrix , int64_t row, int64_t col); bool COOIsNonZero(COOMatrix , int64_t row, int64_t col);
/*! /*!
* \brief Batched implementation of COOIsNonZero. * @brief Batched implementation of COOIsNonZero.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/ */
runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col); runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col);
/*! \brief Return the nnz of the given row */ /*! @brief Return the nnz of the given row */
int64_t COOGetRowNNZ(COOMatrix , int64_t row); int64_t COOGetRowNNZ(COOMatrix , int64_t row);
runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row); runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row);
/*! \brief Return the data array of the given row */ /*! @brief Return the data array of the given row */
std::pair<runtime::NDArray, runtime::NDArray> std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix , int64_t row); COOGetRowDataAndIndices(COOMatrix , int64_t row);
/*! \brief Whether the COO matrix contains data */ /*! @brief Whether the COO matrix contains data */
inline bool COOHasData(COOMatrix csr) { inline bool COOHasData(COOMatrix csr) {
return !IsNullArray(csr.data); return !IsNullArray(csr.data);
} }
/*! /*!
* \brief Check whether the COO is sorted. * @brief Check whether the COO is sorted.
* *
* It returns two flags: one for whether the row is sorted; * It returns two flags: one for whether the row is sorted;
* the other for whether the columns of each row is sorted * the other for whether the columns of each row is sorted
...@@ -215,22 +215,22 @@ inline bool COOHasData(COOMatrix csr) { ...@@ -215,22 +215,22 @@ inline bool COOHasData(COOMatrix csr) {
std::pair<bool, bool> COOIsSorted(COOMatrix coo); std::pair<bool, bool> COOIsSorted(COOMatrix coo);
/*! /*!
* \brief Get the data and the row,col indices for each returned entries. * @brief Get the data and the row,col indices for each returned entries.
* *
* The operator supports matrix with duplicate entries and all the matched entries * The operator supports matrix with duplicate entries and all the matched entries
* will be returned. The operator assumes there is NO duplicate (row, col) pair * will be returned. The operator assumes there is NO duplicate (row, col) pair
* in the given input. Otherwise, the returned result is undefined. * in the given input. Otherwise, the returned result is undefined.
* *
* \note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
* \param mat Sparse matrix * @param mat Sparse matrix
* \param rows Row index * @param rows Row index
* \param cols Column index * @param cols Column index
* \return Three arrays {rows, cols, data} * @return Three arrays {rows, cols, data}
*/ */
std::vector<runtime::NDArray> COOGetDataAndIndices( std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols); COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/*! \brief Get data. The return type is an ndarray due to possible duplicate entries. */ /*! @brief Get data. The return type is an ndarray due to possible duplicate entries. */
inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) { inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
IdArray rows = VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx); IdArray rows = VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);
IdArray cols = VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx); IdArray cols = VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);
...@@ -239,26 +239,26 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) { ...@@ -239,26 +239,26 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
} }
/*! /*!
* \brief Get the data for each (row, col) pair. * @brief Get the data for each (row, col) pair.
* *
* The operator supports matrix with duplicate entries but only one matched entry * The operator supports matrix with duplicate entries but only one matched entry
* will be returned for each (row, col) pair. Support duplicate input (row, col) * will be returned for each (row, col) pair. Support duplicate input (row, col)
* pairs. * pairs.
* *
* \note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
* *
* \param mat Sparse matrix. * @param mat Sparse matrix.
* \param rows Row index. * @param rows Row index.
* \param cols Column index. * @param cols Column index.
* \return Data array. The i^th element is the data of (rows[i], cols[i]) * @return Data array. The i^th element is the data of (rows[i], cols[i])
*/ */
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols); runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/*! \brief Return a transposed COO matrix */ /*! @brief Return a transposed COO matrix */
COOMatrix COOTranspose(COOMatrix coo); COOMatrix COOTranspose(COOMatrix coo);
/*! /*!
* \brief Convert COO matrix to CSR matrix. * @brief Convert COO matrix to CSR matrix.
* *
* If the input COO matrix does not have data array, the data array of * If the input COO matrix does not have data array, the data array of
* the result CSR matrix stores a shuffle index for how the entries * the result CSR matrix stores a shuffle index for how the entries
...@@ -276,44 +276,44 @@ COOMatrix COOTranspose(COOMatrix coo); ...@@ -276,44 +276,44 @@ COOMatrix COOTranspose(COOMatrix coo);
* also column sorted. * also column sorted.
* - Otherwise, the conversion is more costly but still is O(nnz). * - Otherwise, the conversion is more costly but still is O(nnz).
* *
* \param coo Input COO matrix. * @param coo Input COO matrix.
* \return CSR matrix. * @return CSR matrix.
*/ */
CSRMatrix COOToCSR(COOMatrix coo); CSRMatrix COOToCSR(COOMatrix coo);
/*! /*!
* \brief Slice rows of the given matrix and return. * @brief Slice rows of the given matrix and return.
* \param coo COO matrix * @param coo COO matrix
* \param start Start row id (inclusive) * @param start Start row id (inclusive)
* \param end End row id (exclusive) * @param end End row id (exclusive)
*/ */
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end); COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows); COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
/*! /*!
* \brief Get the submatrix specified by the row and col ids. * @brief Get the submatrix specified by the row and col ids.
* *
* In numpy notation, given matrix M, row index array I, col index array J * In numpy notation, given matrix M, row index array I, col index array J
* This function returns the submatrix M[I, J]. * This function returns the submatrix M[I, J].
* *
* \param coo The input coo matrix * @param coo The input coo matrix
* \param rows The row index to select * @param rows The row index to select
* \param cols The col index to select * @param cols The col index to select
* \return submatrix * @return submatrix
*/ */
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
/*! \return True if the matrix has duplicate entries */ /*! @return True if the matrix has duplicate entries */
bool COOHasDuplicate(COOMatrix coo); bool COOHasDuplicate(COOMatrix coo);
/*! /*!
* \brief Deduplicate the entries of a sorted COO matrix, replacing the data with the * @brief Deduplicate the entries of a sorted COO matrix, replacing the data with the
* number of occurrences of the row-col coordinates. * number of occurrences of the row-col coordinates.
*/ */
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo); std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
/*! /*!
* \brief Sort the indices of a COO matrix in-place. * @brief Sort the indices of a COO matrix in-place.
* *
* The function sorts row indices in ascending order. If sort_column is true, * The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix * col indices are sorted in ascending order too. The data array of the returned COOMatrix
...@@ -322,13 +322,13 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo); ...@@ -322,13 +322,13 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros. * Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space. * TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space.
* *
* \param mat The coo matrix to sort. * @param mat The coo matrix to sort.
* \param sort_column True if column index should be sorted too. * @param sort_column True if column index should be sorted too.
*/ */
void COOSort_(COOMatrix* mat, bool sort_column = false); void COOSort_(COOMatrix* mat, bool sort_column = false);
/*! /*!
* \brief Sort the indices of a COO matrix. * @brief Sort the indices of a COO matrix.
* *
* The function sorts row indices in ascending order. If sort_column is true, * The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix * col indices are sorted in ascending order too. The data array of the returned COOMatrix
...@@ -337,9 +337,9 @@ void COOSort_(COOMatrix* mat, bool sort_column = false); ...@@ -337,9 +337,9 @@ void COOSort_(COOMatrix* mat, bool sort_column = false);
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros. * Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space. * TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space.
* *
* \param mat The input coo matrix * @param mat The input coo matrix
* \param sort_column True if column index should be sorted too. * @param sort_column True if column index should be sorted too.
* \return COO matrix with index sorted. * @return COO matrix with index sorted.
*/ */
inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) { inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) {
if ((mat.row_sorted && !sort_column) || mat.col_sorted) if ((mat.row_sorted && !sort_column) || mat.col_sorted)
...@@ -353,22 +353,22 @@ inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) { ...@@ -353,22 +353,22 @@ inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) {
} }
/*! /*!
* \brief Remove entries from COO matrix by entry indices (data indices) * @brief Remove entries from COO matrix by entry indices (data indices)
* \return A new COO matrix as well as a mapping from the new COO entries to the old COO * @return A new COO matrix as well as a mapping from the new COO entries to the old COO
* entries. * entries.
*/ */
COOMatrix COORemove(COOMatrix coo, IdArray entries); COOMatrix COORemove(COOMatrix coo, IdArray entries);
/*! /*!
* \brief Reorder the rows and colmns according to the new row and column order. * @brief Reorder the rows and colmns according to the new row and column order.
* \param csr The input coo matrix. * @param csr The input coo matrix.
* \param new_row_ids the new row Ids (the index is the old row Id) * @param new_row_ids the new row Ids (the index is the old row Id)
* \param new_col_ids the new column Ids (the index is the old col Id). * @param new_col_ids the new column Ids (the index is the old col Id).
*/ */
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/*! /*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently. * @brief Randomly select a fixed number of non-zero entries along each given row independently.
* *
* The function performs random choices along each row independently. * The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix. * The picked indices are returned in the form of a COO matrix.
...@@ -393,14 +393,14 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArr ...@@ -393,14 +393,14 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArr
* // sampled.cols = [1, 2, 3] * // sampled.cols = [1, 2, 3]
* // sampled.data = [3, 0, 4] * // sampled.data = [3, 0, 4]
* *
* \param mat Input coo matrix. * @param mat Input coo matrix.
* \param rows Rows to sample from. * @param rows Rows to sample from.
* \param num_samples Number of samples * @param num_samples Number of samples
* \param prob_or_mask Unnormalized probability array or mask array. * @param prob_or_mask Unnormalized probability array or mask array.
* Should be of the same length as the data array. * Should be of the same length as the data array.
* If an empty array is provided, assume uniform. * If an empty array is provided, assume uniform.
* \param replace True if sample with replacement * @param replace True if sample with replacement
* \return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array. * the index of the picked elements in the value array.
*/ */
COOMatrix COORowWiseSampling( COOMatrix COORowWiseSampling(
...@@ -411,7 +411,7 @@ COOMatrix COORowWiseSampling( ...@@ -411,7 +411,7 @@ COOMatrix COORowWiseSampling(
bool replace = true); bool replace = true);
/*! /*!
* \brief Randomly select a fixed number of non-zero entries for each edge type * @brief Randomly select a fixed number of non-zero entries for each edge type
* along each given row independently. * along each given row independently.
* *
* The function performs random choices along each row independently. * The function performs random choices along each row independently.
...@@ -442,17 +442,17 @@ COOMatrix COORowWiseSampling( ...@@ -442,17 +442,17 @@ COOMatrix COORowWiseSampling(
* // sampled.cols = [0, 3, 2, 3] * // sampled.cols = [0, 3, 2, 3]
* // sampled.data = [2, 0, 1, 4] * // sampled.data = [2, 0, 1, 4]
* *
* \param mat Input coo matrix. * @param mat Input coo matrix.
* \param rows Rows to sample from. * @param rows Rows to sample from.
* \param eid2etype_offset The offset to each edge type. * @param eid2etype_offset The offset to each edge type.
* \param num_samples Number of samples * @param num_samples Number of samples
* \param prob_or_mask Unnormalized probability array or mask array. * @param prob_or_mask Unnormalized probability array or mask array.
* Should be of the same length as the data array. * Should be of the same length as the data array.
* If an empty array is provided, assume uniform. * If an empty array is provided, assume uniform.
* \param replace True if sample with replacement * @param replace True if sample with replacement
* \return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array. * the index of the picked elements in the value array.
* \note The edges of the entire graph must be ordered by their edge types. * @note The edges of the entire graph must be ordered by their edge types.
*/ */
COOMatrix COORowWisePerEtypeSampling( COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, COOMatrix mat,
...@@ -463,7 +463,7 @@ COOMatrix COORowWisePerEtypeSampling( ...@@ -463,7 +463,7 @@ COOMatrix COORowWisePerEtypeSampling(
bool replace = true); bool replace = true);
/*! /*!
* \brief Select K non-zero entries with the largest weights along each given row. * @brief Select K non-zero entries with the largest weights along each given row.
* *
* The function performs top-k selection along each row independently. * The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix. * The picked indices are returned in the form of a COO matrix.
...@@ -489,14 +489,14 @@ COOMatrix COORowWisePerEtypeSampling( ...@@ -489,14 +489,14 @@ COOMatrix COORowWisePerEtypeSampling(
* // sampled.cols = [1, 1, 2] * // sampled.cols = [1, 1, 2]
* // sampled.data = [3, 0, 1] * // sampled.data = [3, 0, 1]
* *
* \param mat Input COO matrix. * @param mat Input COO matrix.
* \param rows Rows to sample from. * @param rows Rows to sample from.
* \param k The K value. * @param k The K value.
* \param weight Weight associated with each entry. Should be of the same length as the * @param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform. * data array. If an empty array is provided, assume uniform.
* \param ascending If true, elements are sorted by ascending order, equivalent to find * @param ascending If true, elements are sorted by ascending order, equivalent to find
* the K smallest values. Otherwise, find K largest values. * the K smallest values. Otherwise, find K largest values.
* \return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array. * the index of the picked elements in the value array.
*/ */
COOMatrix COORowWiseTopk( COOMatrix COORowWiseTopk(
...@@ -507,7 +507,7 @@ COOMatrix COORowWiseTopk( ...@@ -507,7 +507,7 @@ COOMatrix COORowWiseTopk(
bool ascending = false); bool ascending = false);
/*! /*!
* \brief Union two COOMatrix into one COOMatrix. * @brief Union two COOMatrix into one COOMatrix.
* *
* Two Matrix must have the same shape. * Two Matrix must have the same shape.
* *
...@@ -539,7 +539,7 @@ COOMatrix UnionCoo( ...@@ -539,7 +539,7 @@ COOMatrix UnionCoo(
const std::vector<COOMatrix>& coos); const std::vector<COOMatrix>& coos);
/*! /*!
* \brief DisjointUnion a list COOMatrix into one COOMatrix. * @brief DisjointUnion a list COOMatrix into one COOMatrix.
* *
* Examples: * Examples:
* *
...@@ -565,16 +565,16 @@ COOMatrix UnionCoo( ...@@ -565,16 +565,16 @@ COOMatrix UnionCoo(
* COOMatrix_C.num_rows : 5 * COOMatrix_C.num_rows : 5
* COOMatrix_C.num_cols : 5 * COOMatrix_C.num_cols : 5
* *
* \param coos The input list of coo matrix. * @param coos The input list of coo matrix.
* \param src_offset A list of integers recording src vertix id offset of each Matrix in coos * @param src_offset A list of integers recording src vertix id offset of each Matrix in coos
* \param src_offset A list of integers recording dst vertix id offset of each Matrix in coos * @param src_offset A list of integers recording dst vertix id offset of each Matrix in coos
* \return The combined COOMatrix. * @return The combined COOMatrix.
*/ */
COOMatrix DisjointUnionCoo( COOMatrix DisjointUnionCoo(
const std::vector<COOMatrix>& coos); const std::vector<COOMatrix>& coos);
/*! /*!
* \brief COOMatrix toSimple. * @brief COOMatrix toSimple.
* *
* A = [[0, 0, 0], * A = [[0, 0, 0],
* [3, 0, 2], * [3, 0, 2],
...@@ -590,7 +590,7 @@ COOMatrix DisjointUnionCoo( ...@@ -590,7 +590,7 @@ COOMatrix DisjointUnionCoo(
* cnt = [3, 2, 1, 1, 4] * cnt = [3, 2, 1, 1, 4]
* edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4] * edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4]
* *
* \return The simplified COOMatrix * @return The simplified COOMatrix
* The count recording the number of duplicated edges from the original graph. * The count recording the number of duplicated edges from the original graph.
* The edge mapping from the edge IDs of original graph to those of the * The edge mapping from the edge IDs of original graph to those of the
* returned graph. * returned graph.
...@@ -598,7 +598,7 @@ COOMatrix DisjointUnionCoo( ...@@ -598,7 +598,7 @@ COOMatrix DisjointUnionCoo(
std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo); std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
/*! /*!
* \brief Split a COOMatrix into multiple disjoin components. * @brief Split a COOMatrix into multiple disjoin components.
* *
* Examples: * Examples:
* *
...@@ -634,12 +634,12 @@ std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo); ...@@ -634,12 +634,12 @@ std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
* COOMatrix_B.num_rows : 3 * COOMatrix_B.num_rows : 3
* COOMatrix_B.num_cols : 2 * COOMatrix_B.num_cols : 2
* *
* \param coo COOMatrix to split. * @param coo COOMatrix to split.
* \param batch_size Number of disjoin components (Sub COOMatrix) * @param batch_size Number of disjoin components (Sub COOMatrix)
* \param edge_cumsum Number of edges of each components * @param edge_cumsum Number of edges of each components
* \param src_vertex_cumsum Number of src vertices of each component. * @param src_vertex_cumsum Number of src vertices of each component.
* \param dst_vertex_cumsum Number of dst vertices of each component. * @param dst_vertex_cumsum Number of dst vertices of each component.
* \return A list of COOMatrixes representing each disjoint components. * @return A list of COOMatrixes representing each disjoint components.
*/ */
std::vector<COOMatrix> DisjointPartitionCooBySizes( std::vector<COOMatrix> DisjointPartitionCooBySizes(
const COOMatrix &coo, const COOMatrix &coo,
...@@ -649,7 +649,7 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes( ...@@ -649,7 +649,7 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
const std::vector<uint64_t> &dst_vertex_cumsum); const std::vector<uint64_t> &dst_vertex_cumsum);
/*! /*!
* \brief Slice a contiguous chunk from a COOMatrix * @brief Slice a contiguous chunk from a COOMatrix
* *
* Examples: * Examples:
* *
...@@ -677,11 +677,11 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes( ...@@ -677,11 +677,11 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
* COOMatrix_ret.num_rows : 3 * COOMatrix_ret.num_rows : 3
* COOMatrix_ret.num_cols : 2 * COOMatrix_ret.num_cols : 2
* *
* \param coo COOMatrix to slice. * @param coo COOMatrix to slice.
* \param edge_range ID range of the edges in the chunk * @param edge_range ID range of the edges in the chunk
* \param src_vertex_range ID range of the src vertices in the chunk. * @param src_vertex_range ID range of the src vertices in the chunk.
* \param dst_vertex_range ID range of the dst vertices in the chunk. * @param dst_vertex_range ID range of the dst vertices in the chunk.
* \return COOMatrix representing the chunk. * @return COOMatrix representing the chunk.
*/ */
COOMatrix COOSliceContiguousChunk( COOMatrix COOSliceContiguousChunk(
const COOMatrix &coo, const COOMatrix &coo,
...@@ -690,7 +690,7 @@ COOMatrix COOSliceContiguousChunk( ...@@ -690,7 +690,7 @@ COOMatrix COOSliceContiguousChunk(
const std::vector<uint64_t> &dst_vertex_range); const std::vector<uint64_t> &dst_vertex_range);
/*! /*!
* \brief Create a LineGraph of input coo * @brief Create a LineGraph of input coo
* *
* A = [[0, 0, 1], * A = [[0, 0, 1],
* [1, 0, 1], * [1, 0, 1],
...@@ -715,9 +715,9 @@ COOMatrix COOSliceContiguousChunk( ...@@ -715,9 +715,9 @@ COOMatrix COOSliceContiguousChunk(
* [1, 0, 0, 0, 0], * [1, 0, 0, 0, 0],
* [0, 1, 1, 0, 0]] * [0, 1, 1, 0, 0]]
* *
* \param coo COOMatrix to create the LineGraph * @param coo COOMatrix to create the LineGraph
* \param backtracking whether the pair of (v, u) (u, v) edges are treated as linked * @param backtracking whether the pair of (v, u) (u, v) edges are treated as linked
* \return LineGraph in COO format * @return LineGraph in COO format
*/ */
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking); COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
......
/*! /*!
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* \file dgl/aten/csr.h * @file dgl/aten/csr.h
* \brief Common CSR operations required by DGL. * @brief Common CSR operations required by DGL.
*/ */
#ifndef DGL_ATEN_CSR_H_ #ifndef DGL_ATEN_CSR_H_
#define DGL_ATEN_CSR_H_ #define DGL_ATEN_CSR_H_
...@@ -23,7 +23,7 @@ namespace aten { ...@@ -23,7 +23,7 @@ namespace aten {
struct COOMatrix; struct COOMatrix;
/*! /*!
* \brief Plain CSR matrix * @brief Plain CSR matrix
* *
* The column indices are 0-based and are not necessarily sorted. The data array stores * The column indices are 0-based and are not necessarily sorted. The data array stores
* integer ids for reading edge features. * integer ids for reading edge features.
...@@ -36,19 +36,19 @@ struct COOMatrix; ...@@ -36,19 +36,19 @@ struct COOMatrix;
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127; constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
struct CSRMatrix { struct CSRMatrix {
/*! \brief the dense shape of the matrix */ /*! @brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0; int64_t num_rows = 0, num_cols = 0;
/*! \brief CSR index arrays */ /*! @brief CSR index arrays */
IdArray indptr, indices; IdArray indptr, indices;
/*! \brief data index array. When is null, assume it is from 0 to NNZ - 1. */ /*! @brief data index array. When is null, assume it is from 0 to NNZ - 1. */
IdArray data; IdArray data;
/*! \brief whether the column indices per row are sorted */ /*! @brief whether the column indices per row are sorted */
bool sorted = false; bool sorted = false;
/*! \brief whether the matrix is in pinned memory */ /*! @brief whether the matrix is in pinned memory */
bool is_pinned = false; bool is_pinned = false;
/*! \brief default constructor */ /*! @brief default constructor */
CSRMatrix() = default; CSRMatrix() = default;
/*! \brief constructor */ /*! @brief constructor */
CSRMatrix(int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr, CSRMatrix(int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr,
IdArray darr = NullArray(), bool sorted_flag = false) IdArray darr = NullArray(), bool sorted_flag = false)
: num_rows(nrows), : num_rows(nrows),
...@@ -60,7 +60,7 @@ struct CSRMatrix { ...@@ -60,7 +60,7 @@ struct CSRMatrix {
CheckValidity(); CheckValidity();
} }
/*! \brief constructor from SparseMatrix object */ /*! @brief constructor from SparseMatrix object */
explicit CSRMatrix(const SparseMatrix& spmat) explicit CSRMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows), : num_rows(spmat.num_rows),
num_cols(spmat.num_cols), num_cols(spmat.num_cols),
...@@ -114,7 +114,7 @@ struct CSRMatrix { ...@@ -114,7 +114,7 @@ struct CSRMatrix {
CHECK_EQ(indptr->shape[0], num_rows + 1); CHECK_EQ(indptr->shape[0], num_rows + 1);
} }
/*! \brief Return a copy of this matrix on the give device context. */ /*! @brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DGLContext &ctx) const { inline CSRMatrix CopyTo(const DGLContext &ctx) const {
if (ctx == indptr->ctx) if (ctx == indptr->ctx)
return *this; return *this;
...@@ -123,8 +123,8 @@ struct CSRMatrix { ...@@ -123,8 +123,8 @@ struct CSRMatrix {
} }
/*! /*!
* \brief Pin the indptr, indices and data (if not Null) of the matrix. * @brief Pin the indptr, indices and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* kDGLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDGLCUDA: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
...@@ -142,8 +142,8 @@ struct CSRMatrix { ...@@ -142,8 +142,8 @@ struct CSRMatrix {
} }
/*! /*!
* \brief Unpin the indptr, indices and data (if not Null) of the matrix. * @brief Unpin the indptr, indices and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
...@@ -160,8 +160,8 @@ struct CSRMatrix { ...@@ -160,8 +160,8 @@ struct CSRMatrix {
} }
/*! /*!
* \brief Record stream for the indptr, indices and data (if not Null) of the matrix. * @brief Record stream for the indptr, indices and data (if not Null) of the matrix.
* \param stream The stream that is using the graph * @param stream The stream that is using the graph
*/ */
inline void RecordStream(DGLStreamHandle stream) const { inline void RecordStream(DGLStreamHandle stream) const {
indptr.RecordStream(stream); indptr.RecordStream(stream);
...@@ -174,34 +174,34 @@ struct CSRMatrix { ...@@ -174,34 +174,34 @@ struct CSRMatrix {
///////////////////////// CSR routines ////////////////////////// ///////////////////////// CSR routines //////////////////////////
/*! \brief Return true if the value (row, col) is non-zero */ /*! @brief Return true if the value (row, col) is non-zero */
bool CSRIsNonZero(CSRMatrix , int64_t row, int64_t col); bool CSRIsNonZero(CSRMatrix , int64_t row, int64_t col);
/*! /*!
* \brief Batched implementation of CSRIsNonZero. * @brief Batched implementation of CSRIsNonZero.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/ */
runtime::NDArray CSRIsNonZero(CSRMatrix, runtime::NDArray row, runtime::NDArray col); runtime::NDArray CSRIsNonZero(CSRMatrix, runtime::NDArray row, runtime::NDArray col);
/*! \brief Return the nnz of the given row */ /*! @brief Return the nnz of the given row */
int64_t CSRGetRowNNZ(CSRMatrix , int64_t row); int64_t CSRGetRowNNZ(CSRMatrix , int64_t row);
runtime::NDArray CSRGetRowNNZ(CSRMatrix , runtime::NDArray row); runtime::NDArray CSRGetRowNNZ(CSRMatrix , runtime::NDArray row);
/*! \brief Return the column index array of the given row */ /*! @brief Return the column index array of the given row */
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix , int64_t row); runtime::NDArray CSRGetRowColumnIndices(CSRMatrix , int64_t row);
/*! \brief Return the data array of the given row */ /*! @brief Return the data array of the given row */
runtime::NDArray CSRGetRowData(CSRMatrix , int64_t row); runtime::NDArray CSRGetRowData(CSRMatrix , int64_t row);
/*! \brief Whether the CSR matrix contains data */ /*! @brief Whether the CSR matrix contains data */
inline bool CSRHasData(CSRMatrix csr) { inline bool CSRHasData(CSRMatrix csr) {
return !IsNullArray(csr.data); return !IsNullArray(csr.data);
} }
/*! \brief Whether the column indices of each row is sorted. */ /*! @brief Whether the column indices of each row is sorted. */
bool CSRIsSorted(CSRMatrix csr); bool CSRIsSorted(CSRMatrix csr);
/*! /*!
* \brief Get the data and the row,col indices for each returned entries. * @brief Get the data and the row,col indices for each returned entries.
* *
* The operator supports matrix with duplicate entries and all the matched entries * The operator supports matrix with duplicate entries and all the matched entries
* will be returned. The operator assumes there is NO duplicate (row, col) pair * will be returned. The operator assumes there is NO duplicate (row, col) pair
...@@ -210,16 +210,16 @@ bool CSRIsSorted(CSRMatrix csr); ...@@ -210,16 +210,16 @@ bool CSRIsSorted(CSRMatrix csr);
* If some (row, col) pairs do not contain a valid non-zero elements, * If some (row, col) pairs do not contain a valid non-zero elements,
* they will not be included in the return arrays. * they will not be included in the return arrays.
* *
* \note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
* \param mat Sparse matrix * @param mat Sparse matrix
* \param rows Row index * @param rows Row index
* \param cols Column index * @param cols Column index
* \return Three arrays {rows, cols, data} * @return Three arrays {rows, cols, data}
*/ */
std::vector<runtime::NDArray> CSRGetDataAndIndices( std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix , runtime::NDArray rows, runtime::NDArray cols); CSRMatrix , runtime::NDArray rows, runtime::NDArray cols);
/* \brief Get data. The return type is an ndarray due to possible duplicate entries. */ /* @brief Get data. The return type is an ndarray due to possible duplicate entries. */
inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) { inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
const auto& nbits = mat.indptr->dtype.bits; const auto& nbits = mat.indptr->dtype.bits;
const auto& ctx = mat.indptr->ctx; const auto& ctx = mat.indptr->ctx;
...@@ -230,7 +230,7 @@ inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) { ...@@ -230,7 +230,7 @@ inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
} }
/*! /*!
* \brief Get the data for each (row, col) pair. * @brief Get the data for each (row, col) pair.
* *
* The operator supports matrix with duplicate entries but only one matched entry * The operator supports matrix with duplicate entries but only one matched entry
* will be returned for each (row, col) pair. Support duplicate input (row, col) * will be returned for each (row, col) pair. Support duplicate input (row, col)
...@@ -239,17 +239,17 @@ inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) { ...@@ -239,17 +239,17 @@ inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
* If some (row, col) pairs do not contain a valid non-zero elements, * If some (row, col) pairs do not contain a valid non-zero elements,
* their data values are filled with -1. * their data values are filled with -1.
* *
* \note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
* *
* \param mat Sparse matrix. * @param mat Sparse matrix.
* \param rows Row index. * @param rows Row index.
* \param cols Column index. * @param cols Column index.
* \return Data array. The i^th element is the data of (rows[i], cols[i]) * @return Data array. The i^th element is the data of (rows[i], cols[i])
*/ */
runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray cols); runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
/*! /*!
* \brief Get the data for each (row, col) pair, then index into the weights array. * @brief Get the data for each (row, col) pair, then index into the weights array.
* *
* The operator supports matrix with duplicate entries but only one matched entry * The operator supports matrix with duplicate entries but only one matched entry
* will be returned for each (row, col) pair. Support duplicate input (row, col) * will be returned for each (row, col) pair. Support duplicate input (row, col)
...@@ -258,26 +258,26 @@ runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray c ...@@ -258,26 +258,26 @@ runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray c
* If some (row, col) pairs do not contain a valid non-zero elements to index into the * If some (row, col) pairs do not contain a valid non-zero elements to index into the
* weights array, DGL returns the value \a filler for that pair instead. * weights array, DGL returns the value \a filler for that pair instead.
* *
* \note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
* *
* \tparam DType the data type of the weights array. * @tparam DType the data type of the weights array.
* \param mat Sparse matrix. * @param mat Sparse matrix.
* \param rows Row index. * @param rows Row index.
* \param cols Column index. * @param cols Column index.
* \param weights The weights array. * @param weights The weights array.
* \param filler The value to return for row-column pairs not existent in the matrix. * @param filler The value to return for row-column pairs not existent in the matrix.
* \return Data array. The i^th element is the data of (rows[i], cols[i]) * @return Data array. The i^th element is the data of (rows[i], cols[i])
*/ */
template <typename DType> template <typename DType>
runtime::NDArray CSRGetData( runtime::NDArray CSRGetData(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols, runtime::NDArray weights, CSRMatrix, runtime::NDArray rows, runtime::NDArray cols, runtime::NDArray weights,
DType filler); DType filler);
/*! \brief Return a transposed CSR matrix */ /*! @brief Return a transposed CSR matrix */
CSRMatrix CSRTranspose(CSRMatrix csr); CSRMatrix CSRTranspose(CSRMatrix csr);
/*! /*!
* \brief Convert CSR matrix to COO matrix. * @brief Convert CSR matrix to COO matrix.
* *
* Complexity: O(nnz) * Complexity: O(nnz)
* *
...@@ -287,17 +287,17 @@ CSRMatrix CSRTranspose(CSRMatrix csr); ...@@ -287,17 +287,17 @@ CSRMatrix CSRTranspose(CSRMatrix csr);
* - If the input CSR is further sorted, the result COO is also * - If the input CSR is further sorted, the result COO is also
* column sorted. * column sorted.
* *
* \param csr Input csr matrix * @param csr Input csr matrix
* \param data_as_order If true, the data array in the input csr matrix contains the order * @param data_as_order If true, the data array in the input csr matrix contains the order
* by which the resulting COO tuples are stored. In this case, the * by which the resulting COO tuples are stored. In this case, the
* data array of the resulting COO matrix will be empty because it * data array of the resulting COO matrix will be empty because it
* is essentially a consecutive range. * is essentially a consecutive range.
* \return a coo matrix * @return a coo matrix
*/ */
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order); COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order);
/*! /*!
* \brief Slice rows of the given matrix and return. * @brief Slice rows of the given matrix and return.
* *
* The sliced row IDs are relabeled to starting from zero. * The sliced row IDs are relabeled to starting from zero.
* *
...@@ -314,16 +314,16 @@ COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order); ...@@ -314,16 +314,16 @@ COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order);
* indptr = [0, 1, 1] * indptr = [0, 1, 1]
* indices = [2] * indices = [2]
* *
* \param csr CSR matrix * @param csr CSR matrix
* \param start Start row id (inclusive) * @param start Start row id (inclusive)
* \param end End row id (exclusive) * @param end End row id (exclusive)
* \return sliced rows stored in a CSR matrix * @return sliced rows stored in a CSR matrix
*/ */
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end); CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows); CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
/*! /*!
* \brief Get the submatrix specified by the row and col ids. * @brief Get the submatrix specified by the row and col ids.
* *
* In numpy notation, given matrix M, row index array I, col index array J * In numpy notation, given matrix M, row index array I, col index array J
* This function returns the submatrix M[I, J]. It assumes that there is no * This function returns the submatrix M[I, J]. It assumes that there is no
...@@ -334,18 +334,18 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows); ...@@ -334,18 +334,18 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
* rows and cols (i.e., row #0 in the new matrix corresponds to rows[0] in * rows and cols (i.e., row #0 in the new matrix corresponds to rows[0] in
* the original matrix). * the original matrix).
* *
* \param csr The input csr matrix * @param csr The input csr matrix
* \param rows The row index to select * @param rows The row index to select
* \param cols The col index to select * @param cols The col index to select
* \return submatrix * @return submatrix
*/ */
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
/*! \return True if the matrix has duplicate entries */ /*! @return True if the matrix has duplicate entries */
bool CSRHasDuplicate(CSRMatrix csr); bool CSRHasDuplicate(CSRMatrix csr);
/*! /*!
* \brief Sort the column index at each row in ascending order in-place. * @brief Sort the column index at each row in ascending order in-place.
* *
* Only the indices and data arrays (if available) will be mutated. The indptr array * Only the indices and data arrays (if available) will be mutated. The indptr array
* stays the same. * stays the same.
...@@ -364,7 +364,7 @@ bool CSRHasDuplicate(CSRMatrix csr); ...@@ -364,7 +364,7 @@ bool CSRHasDuplicate(CSRMatrix csr);
void CSRSort_(CSRMatrix* csr); void CSRSort_(CSRMatrix* csr);
/*! /*!
* \brief Sort the column index at each row in ascending order. * @brief Sort the column index at each row in ascending order.
* *
* Return a new CSR matrix with sorted column indices and data arrays. * Return a new CSR matrix with sorted column indices and data arrays.
*/ */
...@@ -380,22 +380,22 @@ inline CSRMatrix CSRSort(CSRMatrix csr) { ...@@ -380,22 +380,22 @@ inline CSRMatrix CSRSort(CSRMatrix csr) {
} }
/*! /*!
* \brief Reorder the rows and colmns according to the new row and column order. * @brief Reorder the rows and colmns according to the new row and column order.
* \param csr The input csr matrix. * @param csr The input csr matrix.
* \param new_row_ids the new row Ids (the index is the old row Id) * @param new_row_ids the new row Ids (the index is the old row Id)
* \param new_col_ids the new column Ids (the index is the old col Id). * @param new_col_ids the new column Ids (the index is the old col Id).
*/ */
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/*! /*!
* \brief Remove entries from CSR matrix by entry indices (data indices) * @brief Remove entries from CSR matrix by entry indices (data indices)
* \return A new CSR matrix as well as a mapping from the new CSR entries to the old CSR * @return A new CSR matrix as well as a mapping from the new CSR entries to the old CSR
* entries. * entries.
*/ */
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries); CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
/*! /*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently. * @brief Randomly select a fixed number of non-zero entries along each given row independently.
* *
* The function performs random choices along each row independently. * The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix. * The picked indices are returned in the form of a COO matrix.
...@@ -420,15 +420,15 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries); ...@@ -420,15 +420,15 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
* // sampled.cols = [1, 2, 3] * // sampled.cols = [1, 2, 3]
* // sampled.data = [3, 0, 4] * // sampled.data = [3, 0, 4]
* *
* \param mat Input CSR matrix. * @param mat Input CSR matrix.
* \param rows Rows to sample from. * @param rows Rows to sample from.
* \param num_samples Number of samples * @param num_samples Number of samples
* \param prob_or_mask Unnormalized probability array or mask array. * @param prob_or_mask Unnormalized probability array or mask array.
* Should be of the same length as the data array. * Should be of the same length as the data array.
* If an empty array is provided, assume uniform. * If an empty array is provided, assume uniform.
* \param replace True if sample with replacement * @param replace True if sample with replacement
* \return A COOMatrix storing the picked row, col and data indices. * @return A COOMatrix storing the picked row, col and data indices.
* \note The edges of the entire graph must be ordered by their edge types. * @note The edges of the entire graph must be ordered by their edge types.
*/ */
COOMatrix CSRRowWiseSampling( COOMatrix CSRRowWiseSampling(
CSRMatrix mat, CSRMatrix mat,
...@@ -438,7 +438,7 @@ COOMatrix CSRRowWiseSampling( ...@@ -438,7 +438,7 @@ COOMatrix CSRRowWiseSampling(
bool replace = true); bool replace = true);
/*! /*!
* \brief Randomly select a fixed number of non-zero entries for each edge type * @brief Randomly select a fixed number of non-zero entries for each edge type
* along each given row independently. * along each given row independently.
* *
* The function performs random choices along each row independently. * The function performs random choices along each row independently.
...@@ -469,17 +469,17 @@ COOMatrix CSRRowWiseSampling( ...@@ -469,17 +469,17 @@ COOMatrix CSRRowWiseSampling(
* // sampled.cols = [0, 3, 2, 3] * // sampled.cols = [0, 3, 2, 3]
* // sampled.data = [2, 0, 1, 4] * // sampled.data = [2, 0, 1, 4]
* *
* \param mat Input CSR matrix. * @param mat Input CSR matrix.
* \param rows Rows to sample from. * @param rows Rows to sample from.
* \param eid2etype_offset The offset to each edge type. * @param eid2etype_offset The offset to each edge type.
* \param num_samples Number of samples to choose per edge type. * @param num_samples Number of samples to choose per edge type.
* \param prob_or_mask Unnormalized probability array or mask array. * @param prob_or_mask Unnormalized probability array or mask array.
* Should be of the same length as the data array. * Should be of the same length as the data array.
* If an empty array is provided, assume uniform. * If an empty array is provided, assume uniform.
* \param replace True if sample with replacement * @param replace True if sample with replacement
* \param rowwise_etype_sorted whether the CSR column indices per row are ordered by edge type. * @param rowwise_etype_sorted whether the CSR column indices per row are ordered by edge type.
* \return A COOMatrix storing the picked row, col and data indices. * @return A COOMatrix storing the picked row, col and data indices.
* \note The edges must be ordered by their edge types. * @note The edges must be ordered by their edge types.
*/ */
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, CSRMatrix mat,
...@@ -491,7 +491,7 @@ COOMatrix CSRRowWisePerEtypeSampling( ...@@ -491,7 +491,7 @@ COOMatrix CSRRowWisePerEtypeSampling(
bool rowwise_etype_sorted = false); bool rowwise_etype_sorted = false);
/*! /*!
* \brief Select K non-zero entries with the largest weights along each given row. * @brief Select K non-zero entries with the largest weights along each given row.
* *
* The function performs top-k selection along each row independently. * The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix. * The picked indices are returned in the form of a COO matrix.
...@@ -517,14 +517,14 @@ COOMatrix CSRRowWisePerEtypeSampling( ...@@ -517,14 +517,14 @@ COOMatrix CSRRowWisePerEtypeSampling(
* // sampled.cols = [1, 1, 2] * // sampled.cols = [1, 1, 2]
* // sampled.data = [3, 0, 1] * // sampled.data = [3, 0, 1]
* *
* \param mat Input CSR matrix. * @param mat Input CSR matrix.
* \param rows Rows to sample from. * @param rows Rows to sample from.
* \param k The K value. * @param k The K value.
* \param weight Weight associated with each entry. Should be of the same length as the * @param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform. * data array. If an empty array is provided, assume uniform.
* \param ascending If true, elements are sorted by ascending order, equivalent to find * @param ascending If true, elements are sorted by ascending order, equivalent to find
* the K smallest values. Otherwise, find K largest values. * the K smallest values. Otherwise, find K largest values.
* \return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array. * the index of the picked elements in the value array.
*/ */
COOMatrix CSRRowWiseTopk( COOMatrix CSRRowWiseTopk(
...@@ -537,7 +537,7 @@ COOMatrix CSRRowWiseTopk( ...@@ -537,7 +537,7 @@ COOMatrix CSRRowWiseTopk(
/*! /*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently, * @brief Randomly select a fixed number of non-zero entries along each given row independently,
* where the probability of columns to be picked can be biased according to its tag. * where the probability of columns to be picked can be biased according to its tag.
* *
* Each column is assigned an integer tag which determines its probability to be sampled. * Each column is assigned an integer tag which determines its probability to be sampled.
...@@ -580,13 +580,13 @@ COOMatrix CSRRowWiseTopk( ...@@ -580,13 +580,13 @@ COOMatrix CSRRowWiseTopk(
* // probability of tag 1 is 0. * // probability of tag 1 is 0.
* *
* *
* \param mat Input CSR matrix. * @param mat Input CSR matrix.
* \param rows Rows to sample from. * @param rows Rows to sample from.
* \param num_samples Number of samples. * @param num_samples Number of samples.
* \param tag_offset The boundaries of tags. Should be of the shape [num_row, num_tags+1] * @param tag_offset The boundaries of tags. Should be of the shape [num_row, num_tags+1]
* \param bias Unnormalized probability array. Should be of length num_tags * @param bias Unnormalized probability array. Should be of length num_tags
* \param replace True if sample with replacement * @param replace True if sample with replacement
* \return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array. * the index of the picked elements in the value array.
* *
*/ */
...@@ -600,19 +600,19 @@ COOMatrix CSRRowWiseSamplingBiased( ...@@ -600,19 +600,19 @@ COOMatrix CSRRowWiseSamplingBiased(
); );
/*! /*!
* \brief Uniformly sample row-column pairs whose entries do not exist in the given * @brief Uniformly sample row-column pairs whose entries do not exist in the given
* sparse matrix using rejection sampling. * sparse matrix using rejection sampling.
* *
* \note The number of samples returned may not necessarily be the number of samples * @note The number of samples returned may not necessarily be the number of samples
* given. * given.
* *
* \param csr The CSR matrix. * @param csr The CSR matrix.
* \param num_samples The number of samples. * @param num_samples The number of samples.
* \param num_trials The number of trials. * @param num_trials The number of trials.
* \param exclude_self_loops Do not include the examples where the row equals the column. * @param exclude_self_loops Do not include the examples where the row equals the column.
* \param replace Whether to sample with replacement. * @param replace Whether to sample with replacement.
* \param redundancy How much redundant negative examples to take in case of duplicate examples. * @param redundancy How much redundant negative examples to take in case of duplicate examples.
* \return A pair of row and column tensors. * @return A pair of row and column tensors.
*/ */
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr, const CSRMatrix& csr,
...@@ -623,7 +623,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( ...@@ -623,7 +623,7 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
double redundancy); double redundancy);
/*! /*!
* \brief Sort the column index according to the tag of each column. * @brief Sort the column index according to the tag of each column.
* *
* Example: * Example:
* indptr = [0, 5, 8] * indptr = [0, 5, 8]
...@@ -643,10 +643,10 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( ...@@ -643,10 +643,10 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
* Return: * Return:
* [[0, 2, 4, 5], [0, 1, 3, 3]] (marked with ^) * [[0, 2, 4, 5], [0, 1, 3, 3]] (marked with ^)
* *
* \param csr The csr matrix to be sorted * @param csr The csr matrix to be sorted
* \param tag_array Tag of each column. IdArray with length num_cols * @param tag_array Tag of each column. IdArray with length num_cols
* \param num_tags Number of tags. It should be equal to max(tag_array)+1. * @param num_tags Number of tags. It should be equal to max(tag_array)+1.
* \return 1. A sorted copy of the given CSR matrix * @return 1. A sorted copy of the given CSR matrix
* 2. The split positions of different tags. NDArray of shape (num_rows, num_tags + 1) * 2. The split positions of different tags. NDArray of shape (num_rows, num_tags + 1)
*/ */
std::pair<CSRMatrix, NDArray> CSRSortByTag( std::pair<CSRMatrix, NDArray> CSRSortByTag(
...@@ -655,7 +655,7 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag( ...@@ -655,7 +655,7 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
int64_t num_tags); int64_t num_tags);
/* /*
* \brief Union two CSRMatrix into one CSRMatrix. * @brief Union two CSRMatrix into one CSRMatrix.
* *
* Two Matrix must have the same shape. * Two Matrix must have the same shape.
* *
...@@ -687,7 +687,7 @@ CSRMatrix UnionCsr( ...@@ -687,7 +687,7 @@ CSRMatrix UnionCsr(
const std::vector<CSRMatrix>& csrs); const std::vector<CSRMatrix>& csrs);
/*! /*!
* \brief Union a list CSRMatrix into one CSRMatrix. * @brief Union a list CSRMatrix into one CSRMatrix.
* *
* Examples: * Examples:
* *
...@@ -713,16 +713,16 @@ CSRMatrix UnionCsr( ...@@ -713,16 +713,16 @@ CSRMatrix UnionCsr(
* CSRMatrix_C.num_rows : 5 * CSRMatrix_C.num_rows : 5
* CSRMatrix_C.num_cols : 5 * CSRMatrix_C.num_cols : 5
* *
* \param csrs The input list of csr matrix. * @param csrs The input list of csr matrix.
* \param src_offset A list of integers recording src vertix id offset of each Matrix in csrs * @param src_offset A list of integers recording src vertix id offset of each Matrix in csrs
* \param src_offset A list of integers recording dst vertix id offset of each Matrix in csrs * @param src_offset A list of integers recording dst vertix id offset of each Matrix in csrs
* \return The combined CSRMatrix. * @return The combined CSRMatrix.
*/ */
CSRMatrix DisjointUnionCsr( CSRMatrix DisjointUnionCsr(
const std::vector<CSRMatrix>& csrs); const std::vector<CSRMatrix>& csrs);
/*! /*!
* \brief CSRMatrix toSimple. * @brief CSRMatrix toSimple.
* *
* A = [[0, 0, 0], * A = [[0, 0, 0],
* [3, 0, 2], * [3, 0, 2],
...@@ -738,7 +738,7 @@ CSRMatrix DisjointUnionCsr( ...@@ -738,7 +738,7 @@ CSRMatrix DisjointUnionCsr(
* cnt = [3, 2, 1, 1, 4] * cnt = [3, 2, 1, 1, 4]
* edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4] * edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4]
* *
* \return The simplified CSRMatrix * @return The simplified CSRMatrix
* The count recording the number of duplicated edges from the original graph. * The count recording the number of duplicated edges from the original graph.
* The edge mapping from the edge IDs of original graph to those of the * The edge mapping from the edge IDs of original graph to those of the
* returned graph. * returned graph.
...@@ -746,7 +746,7 @@ CSRMatrix DisjointUnionCsr( ...@@ -746,7 +746,7 @@ CSRMatrix DisjointUnionCsr(
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr); std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr);
/*! /*!
* \brief Split a CSRMatrix into multiple disjoint components. * @brief Split a CSRMatrix into multiple disjoint components.
* *
* Examples: * Examples:
* *
...@@ -782,12 +782,12 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr); ...@@ -782,12 +782,12 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr);
* CSRMatrix_B.num_rows : 3 * CSRMatrix_B.num_rows : 3
* CSRMatrix_B.num_cols : 2 * CSRMatrix_B.num_cols : 2
* *
* \param csr CSRMatrix to split. * @param csr CSRMatrix to split.
* \param batch_size Number of disjoin components (Sub CSRMatrix) * @param batch_size Number of disjoin components (Sub CSRMatrix)
* \param edge_cumsum Number of edges of each components * @param edge_cumsum Number of edges of each components
* \param src_vertex_cumsum Number of src vertices of each component. * @param src_vertex_cumsum Number of src vertices of each component.
* \param dst_vertex_cumsum Number of dst vertices of each component. * @param dst_vertex_cumsum Number of dst vertices of each component.
* \return A list of CSRMatrixes representing each disjoint components. * @return A list of CSRMatrixes representing each disjoint components.
*/ */
std::vector<CSRMatrix> DisjointPartitionCsrBySizes( std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
const CSRMatrix &csrs, const CSRMatrix &csrs,
...@@ -797,7 +797,7 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes( ...@@ -797,7 +797,7 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
const std::vector<uint64_t> &dst_vertex_cumsum); const std::vector<uint64_t> &dst_vertex_cumsum);
/*! /*!
* \brief Slice a contiguous chunk from a CSRMatrix * @brief Slice a contiguous chunk from a CSRMatrix
* *
* Examples: * Examples:
* *
...@@ -825,11 +825,11 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes( ...@@ -825,11 +825,11 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
* CSRMatrix_ret.num_rows : 3 * CSRMatrix_ret.num_rows : 3
* CSRMatrix_ret.num_cols : 2 * CSRMatrix_ret.num_cols : 2
* *
* \param csr CSRMatrix to slice. * @param csr CSRMatrix to slice.
* \param edge_range ID range of the edges in the chunk * @param edge_range ID range of the edges in the chunk
* \param src_vertex_range ID range of the src vertices in the chunk. * @param src_vertex_range ID range of the src vertices in the chunk.
* \param dst_vertex_range ID range of the dst vertices in the chunk. * @param dst_vertex_range ID range of the dst vertices in the chunk.
* \return CSRMatrix representing the chunk. * @return CSRMatrix representing the chunk.
*/ */
CSRMatrix CSRSliceContiguousChunk( CSRMatrix CSRSliceContiguousChunk(
const CSRMatrix &csr, const CSRMatrix &csr,
......
/*! /*!
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* \file dgl/aten/macro.h * @file dgl/aten/macro.h
* \brief Common macros for aten package. * @brief Common macros for aten package.
*/ */
#ifndef DGL_ATEN_MACRO_H_ #ifndef DGL_ATEN_MACRO_H_
......
...@@ -58,19 +58,19 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -58,19 +58,19 @@ class BaseHeteroGraph : public runtime::Object {
////////////////////// query/operations on meta graph /////////////////////// ////////////////////// query/operations on meta graph ///////////////////////
/*! \return the number of vertex types */ /*! @return the number of vertex types */
virtual uint64_t NumVertexTypes() const { return meta_graph_->NumVertices(); } virtual uint64_t NumVertexTypes() const { return meta_graph_->NumVertices(); }
/*! \return the number of edge types */ /*! @return the number of edge types */
virtual uint64_t NumEdgeTypes() const { return meta_graph_->NumEdges(); } virtual uint64_t NumEdgeTypes() const { return meta_graph_->NumEdges(); }
/*! \return given the edge type, find the source type */ /*! @return given the edge type, find the source type */
virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes( virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(
dgl_type_t etype) const { dgl_type_t etype) const {
return meta_graph_->FindEdge(etype); return meta_graph_->FindEdge(etype);
} }
/*! \return the meta graph */ /*! @return the meta graph */
virtual GraphPtr meta_graph() const { return meta_graph_; } virtual GraphPtr meta_graph() const { return meta_graph_; }
/*! /*!
...@@ -134,34 +134,34 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -134,34 +134,34 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual bool IsMultigraph() const = 0; virtual bool IsMultigraph() const = 0;
/*! \return whether the graph is read-only */ /*! @return whether the graph is read-only */
virtual bool IsReadonly() const = 0; virtual bool IsReadonly() const = 0;
/*! \return the number of vertices in the graph.*/ /*! @return the number of vertices in the graph.*/
virtual uint64_t NumVertices(dgl_type_t vtype) const = 0; virtual uint64_t NumVertices(dgl_type_t vtype) const = 0;
/*! \return the number of vertices for each type in the graph as a vector */ /*! @return the number of vertices for each type in the graph as a vector */
inline virtual std::vector<int64_t> NumVerticesPerType() const { inline virtual std::vector<int64_t> NumVerticesPerType() const {
LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on this object."; LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on this object.";
return {}; return {};
} }
/*! \return the number of edges in the graph.*/ /*! @return the number of edges in the graph.*/
virtual uint64_t NumEdges(dgl_type_t etype) const = 0; virtual uint64_t NumEdges(dgl_type_t etype) const = 0;
/*! \return true if the given vertex is in the graph.*/ /*! @return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const = 0; virtual bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const = 0;
/*! \return a 0-1 array indicating whether the given vertices are in the /*! @return a 0-1 array indicating whether the given vertices are in the
* graph. * graph.
*/ */
virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const = 0; virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const = 0;
/*! \return true if the given edge is in the graph.*/ /*! @return true if the given edge is in the graph.*/
virtual bool HasEdgeBetween( virtual bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0; dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/ /*! @return a 0-1 array indicating whether the given edges are in the graph.*/
virtual BoolArray HasEdgesBetween( virtual BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0; dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0;
......
...@@ -103,21 +103,21 @@ class Graph : public GraphInterface { ...@@ -103,21 +103,21 @@ class Graph : public GraphInterface {
*/ */
bool IsReadonly() const override { return false; } bool IsReadonly() const override { return false; }
/*! \return the number of vertices in the graph.*/ /*! @return the number of vertices in the graph.*/
uint64_t NumVertices() const override { return adjlist_.size(); } uint64_t NumVertices() const override { return adjlist_.size(); }
/*! \return the number of edges in the graph.*/ /*! @return the number of edges in the graph.*/
uint64_t NumEdges() const override { return num_edges_; } uint64_t NumEdges() const override { return num_edges_; }
/*! \return a 0-1 array indicating whether the given vertices are in the /*! @return a 0-1 array indicating whether the given vertices are in the
* graph. * graph.
*/ */
BoolArray HasVertices(IdArray vids) const override; BoolArray HasVertices(IdArray vids) const override;
/*! \return true if the given edge is in the graph.*/ /*! @return true if the given edge is in the graph.*/
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override; bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/ /*! @return a 0-1 array indicating whether the given edges are in the graph.*/
BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override; BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override;
/*! /*!
......
...@@ -172,24 +172,24 @@ class GraphInterface : public runtime::Object { ...@@ -172,24 +172,24 @@ class GraphInterface : public runtime::Object {
*/ */
virtual bool IsReadonly() const = 0; virtual bool IsReadonly() const = 0;
/*! \return the number of vertices in the graph.*/ /*! @return the number of vertices in the graph.*/
virtual uint64_t NumVertices() const = 0; virtual uint64_t NumVertices() const = 0;
/*! \return the number of edges in the graph.*/ /*! @return the number of edges in the graph.*/
virtual uint64_t NumEdges() const = 0; virtual uint64_t NumEdges() const = 0;
/*! \return true if the given vertex is in the graph.*/ /*! @return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_id_t vid) const { return vid < NumVertices(); } virtual bool HasVertex(dgl_id_t vid) const { return vid < NumVertices(); }
/*! \return a 0-1 array indicating whether the given vertices are in the /*! @return a 0-1 array indicating whether the given vertices are in the
* graph. * graph.
*/ */
virtual BoolArray HasVertices(IdArray vids) const = 0; virtual BoolArray HasVertices(IdArray vids) const = 0;
/*! \return true if the given edge is in the graph.*/ /*! @return true if the given edge is in the graph.*/
virtual bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const = 0; virtual bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const = 0;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/ /*! @return a 0-1 array indicating whether the given edges are in the graph.*/
virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const = 0; virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const = 0;
/*! /*!
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file graph/graph_serializer.cc * @file graph/graph_serializer.cc
* \brief DGL serializer APIs * @brief DGL serializer APIs
*/ */
#ifndef DGL_GRAPH_SERIALIZER_H_ #ifndef DGL_GRAPH_SERIALIZER_H_
......
...@@ -225,10 +225,10 @@ class CSR : public GraphInterface { ...@@ -225,10 +225,10 @@ class CSR : public GraphInterface {
IdArray edge_ids() const { return adj_.data; } IdArray edge_ids() const { return adj_.data; }
/*! \return Load CSR from stream */ /*! @return Load CSR from stream */
bool Load(dmlc::Stream *fs); bool Load(dmlc::Stream *fs);
/*! \return Save CSR to stream */ /*! @return Save CSR to stream */
void Save(dmlc::Stream *fs) const; void Save(dmlc::Stream *fs) const;
void SortCSR() override { void SortCSR() override {
...@@ -577,18 +577,18 @@ class ImmutableGraph : public GraphInterface { ...@@ -577,18 +577,18 @@ class ImmutableGraph : public GraphInterface {
return is_unibipartite_; return is_unibipartite_;
} }
/*! \return the number of vertices in the graph.*/ /*! @return the number of vertices in the graph.*/
uint64_t NumVertices() const override { return AnyGraph()->NumVertices(); } uint64_t NumVertices() const override { return AnyGraph()->NumVertices(); }
/*! \return the number of edges in the graph.*/ /*! @return the number of edges in the graph.*/
uint64_t NumEdges() const override { return AnyGraph()->NumEdges(); } uint64_t NumEdges() const override { return AnyGraph()->NumEdges(); }
/*! \return true if the given vertex is in the graph.*/ /*! @return true if the given vertex is in the graph.*/
bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); } bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); }
BoolArray HasVertices(IdArray vids) const override; BoolArray HasVertices(IdArray vids) const override;
/*! \return true if the given edge is in the graph.*/ /*! @return true if the given edge is in the graph.*/
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override { bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
if (in_csr_) { if (in_csr_) {
return in_csr_->HasEdgeBetween(dst, src); return in_csr_->HasEdgeBetween(dst, src);
...@@ -918,10 +918,10 @@ class ImmutableGraph : public GraphInterface { ...@@ -918,10 +918,10 @@ class ImmutableGraph : public GraphInterface {
*/ */
ImmutableGraphPtr Reverse() const; ImmutableGraphPtr Reverse() const;
/*! \return Load ImmutableGraph from stream, using out csr */ /*! @return Load ImmutableGraph from stream, using out csr */
bool Load(dmlc::Stream *fs); bool Load(dmlc::Stream *fs);
/*! \return Save ImmutableGraph to stream, using out csr */ /*! @return Save ImmutableGraph to stream, using out csr */
void Save(dmlc::Stream *fs) const; void Save(dmlc::Stream *fs) const;
void SortCSR() override { void SortCSR() override {
......
...@@ -65,7 +65,7 @@ DGL_DLL void* DGLBackendAllocWorkspace( ...@@ -65,7 +65,7 @@ DGL_DLL void* DGLBackendAllocWorkspace(
* @param device_id The device id which the space will be allocated. * @param device_id The device id which the space will be allocated.
* @return 0 when no error is thrown, -1 when failure happens * @return 0 when no error is thrown, -1 when failure happens
* *
* \sa DGLBackendAllocWorkspace * @sa DGLBackendAllocWorkspace
*/ */
DGL_DLL int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr); DGL_DLL int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr);
......
...@@ -166,7 +166,7 @@ typedef struct { ...@@ -166,7 +166,7 @@ typedef struct {
* For a DGLArray, the size of memory required to store the contents of * For a DGLArray, the size of memory required to store the contents of
* data can be calculated as follows: * data can be calculated as follows:
* *
* \code{.c} * @code{.c}
* static inline size_t GetDataSize(const DGLArray* t) { * static inline size_t GetDataSize(const DGLArray* t) {
* size_t size = 1; * size_t size = 1;
* for (int32_t i = 0; i < t->ndim; ++i) { * for (int32_t i = 0; i < t->ndim; ++i) {
...@@ -175,7 +175,7 @@ typedef struct { ...@@ -175,7 +175,7 @@ typedef struct {
* size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
* return size; * return size;
* } * }
* \endcode * @endcode
*/ */
void* data; void* data;
/*! @brief The device of the tensor */ /*! @brief The device of the tensor */
...@@ -372,7 +372,7 @@ DGL_DLL int DGLCbArgToReturn(DGLValue* value, int code); ...@@ -372,7 +372,7 @@ DGL_DLL int DGLCbArgToReturn(DGLValue* value, int code);
* @param resource_handle The handle additional resouce handle from fron-end. * @param resource_handle The handle additional resouce handle from fron-end.
* @return 0 if success, -1 if failure happens, set error via * @return 0 if success, -1 if failure happens, set error via
* DGLAPISetLastError. * DGLAPISetLastError.
* \sa DGLCFuncSetReturn * @sa DGLCFuncSetReturn
*/ */
typedef int (*DGLPackedCFunc)( typedef int (*DGLPackedCFunc)(
DGLValue* args, int* type_codes, int num_args, DGLRetValueHandle ret, DGLValue* args, int* type_codes, int num_args, DGLRetValueHandle ret,
......
...@@ -284,7 +284,7 @@ class List : public ObjectRef { ...@@ -284,7 +284,7 @@ class List : public ObjectRef {
inline const T operator[](size_t i) const { inline const T operator[](size_t i) const {
return T(static_cast<const ListObject*>(obj_.get())->data[i]); return T(static_cast<const ListObject*>(obj_.get())->data[i]);
} }
/*! \return The size of the list */ /*! @return The size of the list */
inline size_t size() const { inline size_t size() const {
if (obj_.get() == nullptr) return 0; if (obj_.get() == nullptr) return 0;
return static_cast<const ListObject*>(obj_.get())->data.size(); return static_cast<const ListObject*>(obj_.get())->data.size();
...@@ -321,7 +321,7 @@ class List : public ObjectRef { ...@@ -321,7 +321,7 @@ class List : public ObjectRef {
ListObject* n = this->CopyOnWrite(); ListObject* n = this->CopyOnWrite();
n->data[i] = value.obj_; n->data[i] = value.obj_;
} }
/*! \return whether list is empty */ /*! @return whether list is empty */
inline bool empty() const { return size() == 0; } inline bool empty() const { return size() == 0; }
/*! @brief Copy the content to a vector */ /*! @brief Copy the content to a vector */
inline std::vector<T> ToVector() const { inline std::vector<T> ToVector() const {
...@@ -341,20 +341,20 @@ class List : public ObjectRef { ...@@ -341,20 +341,20 @@ class List : public ObjectRef {
Ptr2ObjectRef, Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_reverse_iterator>; std::vector<std::shared_ptr<Object> >::const_reverse_iterator>;
/*! \return begin iterator */ /*! @return begin iterator */
inline iterator begin() const { inline iterator begin() const {
return iterator(static_cast<const ListObject*>(obj_.get())->data.begin()); return iterator(static_cast<const ListObject*>(obj_.get())->data.begin());
} }
/*! \return end iterator */ /*! @return end iterator */
inline iterator end() const { inline iterator end() const {
return iterator(static_cast<const ListObject*>(obj_.get())->data.end()); return iterator(static_cast<const ListObject*>(obj_.get())->data.end());
} }
/*! \return rbegin iterator */ /*! @return rbegin iterator */
inline reverse_iterator rbegin() const { inline reverse_iterator rbegin() const {
return reverse_iterator( return reverse_iterator(
static_cast<const ListObject*>(obj_.get())->data.rbegin()); static_cast<const ListObject*>(obj_.get())->data.rbegin());
} }
/*! \return rend iterator */ /*! @return rend iterator */
inline reverse_iterator rend() const { inline reverse_iterator rend() const {
return reverse_iterator( return reverse_iterator(
static_cast<const ListObject*>(obj_.get())->data.rend()); static_cast<const ListObject*>(obj_.get())->data.rend());
...@@ -498,12 +498,12 @@ class Map : public ObjectRef { ...@@ -498,12 +498,12 @@ class Map : public ObjectRef {
inline const V at(const K& key) const { inline const V at(const K& key) const {
return V(static_cast<const MapObject*>(obj_.get())->data.at(key.obj_)); return V(static_cast<const MapObject*>(obj_.get())->data.at(key.obj_));
} }
/*! \return The size of the list */ /*! @return The size of the list */
inline size_t size() const { inline size_t size() const {
if (obj_.get() == nullptr) return 0; if (obj_.get() == nullptr) return 0;
return static_cast<const MapObject*>(obj_.get())->data.size(); return static_cast<const MapObject*>(obj_.get())->data.size();
} }
/*! \return The size of the list */ /*! @return The size of the list */
inline size_t count(const K& key) const { inline size_t count(const K& key) const {
if (obj_.get() == nullptr) return 0; if (obj_.get() == nullptr) return 0;
return static_cast<const MapObject*>(obj_.get())->data.count(key.obj_); return static_cast<const MapObject*>(obj_.get())->data.count(key.obj_);
...@@ -533,7 +533,7 @@ class Map : public ObjectRef { ...@@ -533,7 +533,7 @@ class Map : public ObjectRef {
n->data[key.obj_] = value.obj_; n->data[key.obj_] = value.obj_;
} }
/*! \return whether list is empty */ /*! @return whether list is empty */
inline bool empty() const { return size() == 0; } inline bool empty() const { return size() == 0; }
/*! @brief specify container obj */ /*! @brief specify container obj */
using ContainerType = MapObject; using ContainerType = MapObject;
...@@ -549,15 +549,15 @@ class Map : public ObjectRef { ...@@ -549,15 +549,15 @@ class Map : public ObjectRef {
using iterator = using iterator =
IterAdapter<Ptr2ObjectRef, MapObject::ContainerType::const_iterator>; IterAdapter<Ptr2ObjectRef, MapObject::ContainerType::const_iterator>;
/*! \return begin iterator */ /*! @return begin iterator */
inline iterator begin() const { inline iterator begin() const {
return iterator(static_cast<const MapObject*>(obj_.get())->data.begin()); return iterator(static_cast<const MapObject*>(obj_.get())->data.begin());
} }
/*! \return end iterator */ /*! @return end iterator */
inline iterator end() const { inline iterator end() const {
return iterator(static_cast<const MapObject*>(obj_.get())->data.end()); return iterator(static_cast<const MapObject*>(obj_.get())->data.end());
} }
/*! \return begin iterator */ /*! @return begin iterator */
inline iterator find(const K& key) const { inline iterator find(const K& key) const {
return iterator( return iterator(
static_cast<const MapObject*>(obj_.get())->data.find(key.obj_)); static_cast<const MapObject*>(obj_.get())->data.find(key.obj_));
...@@ -644,15 +644,15 @@ class Map<std::string, V, T1, T2> : public ObjectRef { ...@@ -644,15 +644,15 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
using iterator = using iterator =
IterAdapter<Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>; IterAdapter<Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>;
/*! \return begin iterator */ /*! @return begin iterator */
inline iterator begin() const { inline iterator begin() const {
return iterator(static_cast<const StrMapObject*>(obj_.get())->data.begin()); return iterator(static_cast<const StrMapObject*>(obj_.get())->data.begin());
} }
/*! \return end iterator */ /*! @return end iterator */
inline iterator end() const { inline iterator end() const {
return iterator(static_cast<const StrMapObject*>(obj_.get())->data.end()); return iterator(static_cast<const StrMapObject*>(obj_.get())->data.end());
} }
/*! \return begin iterator */ /*! @return begin iterator */
inline iterator find(const std::string& key) const { inline iterator find(const std::string& key) const {
return iterator( return iterator(
static_cast<const StrMapObject*>(obj_.get())->data.find(key)); static_cast<const StrMapObject*>(obj_.get())->data.find(key));
......
...@@ -60,7 +60,7 @@ class DeviceAPI { ...@@ -60,7 +60,7 @@ class DeviceAPI {
* @param ctx The device context * @param ctx The device context
* @param kind The result kind * @param kind The result kind
* @param rv The return value. * @param rv The return value.
* \sa DeviceAttrKind * @sa DeviceAttrKind
*/ */
virtual void GetAttr( virtual void GetAttr(
DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0; DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0;
......
...@@ -43,9 +43,9 @@ class Module { ...@@ -43,9 +43,9 @@ class Module {
*/ */
inline PackedFunc GetFunction( inline PackedFunc GetFunction(
const std::string& name, bool query_imports = false); const std::string& name, bool query_imports = false);
/*! \return internal container */ /*! @return internal container */
inline ModuleNode* operator->(); inline ModuleNode* operator->();
/*! \return internal container */ /*! @return internal container */
inline const ModuleNode* operator->() const; inline const ModuleNode* operator->() const;
// The following functions requires link with runtime. // The following functions requires link with runtime.
/*! /*!
...@@ -78,7 +78,7 @@ class ModuleNode { ...@@ -78,7 +78,7 @@ class ModuleNode {
public: public:
/*! @brief virtual destructor */ /*! @brief virtual destructor */
virtual ~ModuleNode() {} virtual ~ModuleNode() {}
/*! \return The module type key */ /*! @return The module type key */
virtual const char* type_key() const = 0; virtual const char* type_key() const = 0;
/*! /*!
* @brief Get a PackedFunc from module. * @brief Get a PackedFunc from module.
...@@ -129,7 +129,7 @@ class ModuleNode { ...@@ -129,7 +129,7 @@ class ModuleNode {
* @return The corresponding function. * @return The corresponding function.
*/ */
DGL_DLL const PackedFunc* GetFuncFromEnv(const std::string& name); DGL_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);
/*! \return The module it imports from */ /*! @return The module it imports from */
const std::vector<Module>& imports() const { return imports_; } const std::vector<Module>& imports() const { return imports_; }
protected: protected:
......
/*! /*!
* Copyright (c) 2017-2022 by Contributors * Copyright (c) 2017-2022 by Contributors
* \file dgl/runtime/ndarray.h * @file dgl/runtime/ndarray.h
* \brief Abstract device memory management API * @brief Abstract device memory management API
*/ */
#ifndef DGL_RUNTIME_NDARRAY_H_ #ifndef DGL_RUNTIME_NDARRAY_H_
#define DGL_RUNTIME_NDARRAY_H_ #define DGL_RUNTIME_NDARRAY_H_
...@@ -34,7 +34,7 @@ inline std::ostream& operator << (std::ostream& os, DGLDataType t); ...@@ -34,7 +34,7 @@ inline std::ostream& operator << (std::ostream& os, DGLDataType t);
namespace dgl { namespace dgl {
/*! /*!
* \brief Type traits that converts a C type to a DGLDataType. * @brief Type traits that converts a C type to a DGLDataType.
* *
* Usage: * Usage:
* DGLDataTypeTraits<int>::dtype == dtype * DGLDataTypeTraits<int>::dtype == dtype
...@@ -70,53 +70,53 @@ GEN_DGLDATATYPETRAITS_FOR(double, kDGLFloat, 64); ...@@ -70,53 +70,53 @@ GEN_DGLDATATYPETRAITS_FOR(double, kDGLFloat, 64);
namespace runtime { namespace runtime {
/*! /*!
* \brief DLPack converter. * @brief DLPack converter.
*/ */
struct DLPackConvert; struct DLPackConvert;
/*! /*!
* \brief Managed NDArray. * @brief Managed NDArray.
* The array is backed by reference counted blocks. * The array is backed by reference counted blocks.
*/ */
class NDArray { class NDArray {
public: public:
// internal container type // internal container type
struct Container; struct Container;
/*! \brief default constructor */ /*! @brief default constructor */
NDArray() {} NDArray() {}
/*! /*!
* \brief cosntruct a NDArray that refers to data * @brief cosntruct a NDArray that refers to data
* \param data The data this NDArray refers to * @param data The data this NDArray refers to
*/ */
explicit inline NDArray(Container* data); explicit inline NDArray(Container* data);
/*! /*!
* \brief copy constructor * @brief copy constructor
* \param other The value to be copied * @param other The value to be copied
*/ */
inline NDArray(const NDArray& other); // NOLINT(*) inline NDArray(const NDArray& other); // NOLINT(*)
/*! /*!
* \brief move constructor * @brief move constructor
* \param other The value to be moved * @param other The value to be moved
*/ */
NDArray(NDArray&& other) // NOLINT(*) NDArray(NDArray&& other) // NOLINT(*)
: data_(other.data_) { : data_(other.data_) {
other.data_ = nullptr; other.data_ = nullptr;
} }
/*! \brief destructor */ /*! @brief destructor */
~NDArray() { ~NDArray() {
this->reset(); this->reset();
} }
/*! /*!
* \brief Swap this array with another NDArray * @brief Swap this array with another NDArray
* \param other The other NDArray * @param other The other NDArray
*/ */
void swap(NDArray& other) { // NOLINT(*) void swap(NDArray& other) { // NOLINT(*)
std::swap(data_, other.data_); std::swap(data_, other.data_);
} }
/*! /*!
* \brief copy assignmemt * @brief copy assignmemt
* \param other The value to be assigned. * @param other The value to be assigned.
* \return reference to self. * @return reference to self.
*/ */
NDArray& operator=(const NDArray& other) { // NOLINT(*) NDArray& operator=(const NDArray& other) { // NOLINT(*)
// copy-and-swap idiom // copy-and-swap idiom
...@@ -124,35 +124,35 @@ class NDArray { ...@@ -124,35 +124,35 @@ class NDArray {
return *this; return *this;
} }
/*! /*!
* \brief move assignmemt * @brief move assignmemt
* \param other The value to be assigned. * @param other The value to be assigned.
* \return reference to self. * @return reference to self.
*/ */
NDArray& operator=(NDArray&& other) { // NOLINT(*) NDArray& operator=(NDArray&& other) { // NOLINT(*)
// copy-and-swap idiom // copy-and-swap idiom
NDArray(std::move(other)).swap(*this); // NOLINT(*) NDArray(std::move(other)).swap(*this); // NOLINT(*)
return *this; return *this;
} }
/*! \return If NDArray is defined */ /*! @return If NDArray is defined */
bool defined() const { bool defined() const {
return data_ != nullptr; return data_ != nullptr;
} }
/*! \return If both NDArray reference the same container */ /*! @return If both NDArray reference the same container */
bool same_as(const NDArray& other) const { bool same_as(const NDArray& other) const {
return data_ == other.data_; return data_ == other.data_;
} }
/*! \brief reset the content of NDArray to be nullptr */ /*! @brief reset the content of NDArray to be nullptr */
inline void reset(); inline void reset();
/*! /*!
* \return the reference counter * @return the reference counter
* \note this number is approximate in multi-threaded setting. * @note this number is approximate in multi-threaded setting.
*/ */
inline int use_count() const; inline int use_count() const;
/*! \return Pointer to content of DGLArray */ /*! @return Pointer to content of DGLArray */
inline const DGLArray* operator->() const; inline const DGLArray* operator->() const;
/*! \return True if the ndarray is contiguous. */ /*! @return True if the ndarray is contiguous. */
bool IsContiguous() const; bool IsContiguous() const;
/*! \return the data pointer with type. */ /*! @return the data pointer with type. */
template <typename T> template <typename T>
inline T* Ptr() const { inline T* Ptr() const {
if (!defined()) if (!defined())
...@@ -161,9 +161,9 @@ class NDArray { ...@@ -161,9 +161,9 @@ class NDArray {
return static_cast<T*>(operator->()->data); return static_cast<T*>(operator->()->data);
} }
/*! /*!
* \brief Copy data content from/into another array. * @brief Copy data content from/into another array.
* \param other The source array to be copied from. * @param other The source array to be copied from.
* \note The copy runs on the dgl internal stream if it involves a GPU context. * @note The copy runs on the dgl internal stream if it involves a GPU context.
*/ */
inline void CopyFrom(DGLArray* other); inline void CopyFrom(DGLArray* other);
inline void CopyFrom(const NDArray& other); inline void CopyFrom(const NDArray& other);
...@@ -171,79 +171,79 @@ class NDArray { ...@@ -171,79 +171,79 @@ class NDArray {
inline void CopyTo(const NDArray &other) const; inline void CopyTo(const NDArray &other) const;
/*! /*!
* \brief Copy the data to another context. * @brief Copy the data to another context.
* \param ctx The target context. * @param ctx The target context.
* \return The array under another context. * @return The array under another context.
*/ */
inline NDArray CopyTo(const DGLContext &ctx) const; inline NDArray CopyTo(const DGLContext &ctx) const;
/*! /*!
* \brief Return a new array with a copy of the content. * @brief Return a new array with a copy of the content.
*/ */
inline NDArray Clone() const; inline NDArray Clone() const;
/*! /*!
* \brief In-place method to pin the current array by calling PinContainer * @brief In-place method to pin the current array by calling PinContainer
* on the underlying NDArray:Container. * on the underlying NDArray:Container.
* \note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* kDGLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDGLCUDA: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
*/ */
inline void PinMemory_(); inline void PinMemory_();
/*! /*!
* \brief In-place method to unpin the current array by calling UnpinContainer * @brief In-place method to unpin the current array by calling UnpinContainer
* on the underlying NDArray:Container. * on the underlying NDArray:Container.
* \note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
*/ */
inline void UnpinMemory_(); inline void UnpinMemory_();
/*! /*!
* \brief Check if the array is pinned. * @brief Check if the array is pinned.
*/ */
inline bool IsPinned() const; inline bool IsPinned() const;
/*! /*!
* \brief Record streams that are using the underlying tensor. * @brief Record streams that are using the underlying tensor.
* \param stream The stream that is using the underlying tensor. * @param stream The stream that is using the underlying tensor.
*/ */
inline void RecordStream(DGLStreamHandle stream) const; inline void RecordStream(DGLStreamHandle stream) const;
/*! /*!
* \brief Load NDArray from stream * @brief Load NDArray from stream
* \param stream The input data stream * @param stream The input data stream
* \return Whether load is successful * @return Whether load is successful
*/ */
bool Load(dmlc::Stream* stream); bool Load(dmlc::Stream* stream);
/*! /*!
* \brief Save NDArray to stream * @brief Save NDArray to stream
* \param stream The output data stream * @param stream The output data stream
*/ */
void Save(dmlc::Stream* stream) const; void Save(dmlc::Stream* stream) const;
/*! /*!
* \brief Create a NDArray that shares the data memory with the current one. * @brief Create a NDArray that shares the data memory with the current one.
* \param shape The shape of the new array. * @param shape The shape of the new array.
* \param dtype The data type of the new array. * @param dtype The data type of the new array.
* \param offset The offset (in bytes) of the starting pointer. * @param offset The offset (in bytes) of the starting pointer.
* \note The memory size of new array must be smaller than the current one. * @note The memory size of new array must be smaller than the current one.
*/ */
DGL_DLL NDArray CreateView( DGL_DLL NDArray CreateView(
std::vector<int64_t> shape, DGLDataType dtype, int64_t offset = 0); std::vector<int64_t> shape, DGLDataType dtype, int64_t offset = 0);
/*! /*!
* \brief Create an empty NDArray. * @brief Create an empty NDArray.
* \param shape The shape of the new array. * @param shape The shape of the new array.
* \param dtype The data type of the new array. * @param dtype The data type of the new array.
* \param ctx The context of the Array. * @param ctx The context of the Array.
* \return The created Array * @return The created Array
*/ */
DGL_DLL static NDArray Empty(std::vector<int64_t> shape, DGL_DLL static NDArray Empty(std::vector<int64_t> shape,
DGLDataType dtype, DGLDataType dtype,
DGLContext ctx); DGLContext ctx);
/*! /*!
* \brief Create an empty NDArray with shared memory. * @brief Create an empty NDArray with shared memory.
* \param name The name of shared memory. * @param name The name of shared memory.
* \param shape The shape of the new array. * @param shape The shape of the new array.
* \param dtype The data type of the new array. * @param dtype The data type of the new array.
* \param ctx The context of the Array. * @param ctx The context of the Array.
* \param is_create whether to create shared memory. * @param is_create whether to create shared memory.
* \return The created Array * @return The created Array
*/ */
DGL_DLL static NDArray EmptyShared(const std::string &name, DGL_DLL static NDArray EmptyShared(const std::string &name,
std::vector<int64_t> shape, std::vector<int64_t> shape,
...@@ -251,33 +251,33 @@ class NDArray { ...@@ -251,33 +251,33 @@ class NDArray {
DGLContext ctx, DGLContext ctx,
bool is_create); bool is_create);
/*! /*!
* \brief Get the size of the array in the number of bytes. * @brief Get the size of the array in the number of bytes.
*/ */
size_t GetSize() const; size_t GetSize() const;
/*! /*!
* \brief Get the number of elements in this array. * @brief Get the number of elements in this array.
*/ */
int64_t NumElements() const; int64_t NumElements() const;
/*! /*!
* \brief Create a NDArray by copying from std::vector. * @brief Create a NDArray by copying from std::vector.
* \tparam T Type of vector data. Determines the dtype of returned array. * @tparam T Type of vector data. Determines the dtype of returned array.
*/ */
template<typename T> template<typename T>
DGL_DLL static NDArray FromVector( DGL_DLL static NDArray FromVector(
const std::vector<T>& vec, DGLContext ctx = DGLContext{kDGLCPU, 0}); const std::vector<T>& vec, DGLContext ctx = DGLContext{kDGLCPU, 0});
/*! /*!
* \brief Create a NDArray from a raw pointer. * @brief Create a NDArray from a raw pointer.
*/ */
DGL_DLL static NDArray CreateFromRaw(const std::vector<int64_t>& shape, DGL_DLL static NDArray CreateFromRaw(const std::vector<int64_t>& shape,
DGLDataType dtype, DGLContext ctx, void* raw, bool auto_free); DGLDataType dtype, DGLContext ctx, void* raw, bool auto_free);
/*! /*!
* \brief Create a std::vector from a 1D NDArray. * @brief Create a std::vector from a 1D NDArray.
* \tparam T Type of vector data. * @tparam T Type of vector data.
* \note Type casting is NOT performed. The caller has to make sure that the vector * @note Type casting is NOT performed. The caller has to make sure that the vector
* type matches the dtype of NDArray. * type matches the dtype of NDArray.
*/ */
template<typename T> template<typename T>
...@@ -286,10 +286,10 @@ class NDArray { ...@@ -286,10 +286,10 @@ class NDArray {
std::shared_ptr<SharedMemory> GetSharedMem() const; std::shared_ptr<SharedMemory> GetSharedMem() const;
/*! /*!
* \brief Function to copy data from one array to another. * @brief Function to copy data from one array to another.
* \param from The source array. * @param from The source array.
* \param to The target array. * @param to The target array.
* \param (optional) stream The stream used in copy. * @param (optional) stream The stream used in copy.
*/ */
DGL_DLL static void CopyFromTo( DGL_DLL static void CopyFromTo(
DGLArray* from, DGLArray* to); DGLArray* from, DGLArray* to);
...@@ -297,9 +297,9 @@ class NDArray { ...@@ -297,9 +297,9 @@ class NDArray {
DGLArray* from, DGLArray* to, DGLStreamHandle stream); DGLArray* from, DGLArray* to, DGLStreamHandle stream);
/*! /*!
* \brief Function to pin the DGLArray of a Container. * @brief Function to pin the DGLArray of a Container.
* \param ptr The container to be pinned. * @param ptr The container to be pinned.
* \note Data of the given array will be pinned inplace. * @note Data of the given array will be pinned inplace.
* Behavior depends on the current context, * Behavior depends on the current context,
* kDGLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
...@@ -308,9 +308,9 @@ class NDArray { ...@@ -308,9 +308,9 @@ class NDArray {
DGL_DLL static void PinContainer(Container* ptr); DGL_DLL static void PinContainer(Container* ptr);
/*! /*!
* \brief Function to unpin the DGLArray of a Container. * @brief Function to unpin the DGLArray of a Container.
* \param ptr The container to be unpinned. * @param ptr The container to be unpinned.
* \note Data of the given array will be unpinned inplace. * @note Data of the given array will be unpinned inplace.
* Behavior depends on the current context, * Behavior depends on the current context,
* IsPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
...@@ -318,16 +318,16 @@ class NDArray { ...@@ -318,16 +318,16 @@ class NDArray {
DGL_DLL static void UnpinContainer(Container* ptr); DGL_DLL static void UnpinContainer(Container* ptr);
/*! /*!
* \brief Function check if the DGLArray of a Container is pinned. * @brief Function check if the DGLArray of a Container is pinned.
* \param ptr The container to be checked. * @param ptr The container to be checked.
* \return true if pinned. * @return true if pinned.
*/ */
DGL_DLL static bool IsContainerPinned(Container* ptr); DGL_DLL static bool IsContainerPinned(Container* ptr);
/*! /*!
* \brief Record streams that are using this tensor. * @brief Record streams that are using this tensor.
* \param ptr Pointer of the tensor to be recorded. * @param ptr Pointer of the tensor to be recorded.
* \param stream The stream that is using this tensor. * @param stream The stream that is using this tensor.
*/ */
DGL_DLL static void RecordStream(DGLArray* tensor, DGLStreamHandle stream); DGL_DLL static void RecordStream(DGLArray* tensor, DGLStreamHandle stream);
...@@ -344,7 +344,7 @@ class NDArray { ...@@ -344,7 +344,7 @@ class NDArray {
}; };
private: private:
/*! \brief Internal Data content */ /*! @brief Internal Data content */
Container* data_{nullptr}; Container* data_{nullptr};
// enable internal functions // enable internal functions
friend struct Internal; friend struct Internal;
...@@ -354,21 +354,21 @@ class NDArray { ...@@ -354,21 +354,21 @@ class NDArray {
}; };
/*! /*!
* \brief Save a DGLArray to stream * @brief Save a DGLArray to stream
* \param strm The outpu stream * @param strm The outpu stream
* \param tensor The tensor to be saved. * @param tensor The tensor to be saved.
*/ */
inline bool SaveDGLArray(dmlc::Stream* strm, const DGLArray* tensor); inline bool SaveDGLArray(dmlc::Stream* strm, const DGLArray* tensor);
/*! /*!
* \brief Reference counted Container object used to back NDArray. * @brief Reference counted Container object used to back NDArray.
* *
* This object is DGLArray compatible: * This object is DGLArray compatible:
* the pointer to the NDArrayContainer can be directly * the pointer to the NDArrayContainer can be directly
* interpreted as a DGLArray* * interpreted as a DGLArray*
* *
* \note: do not use this function directly, use NDArray. * @note: do not use this function directly, use NDArray.
*/ */
struct NDArray::Container { struct NDArray::Container {
public: public:
...@@ -377,28 +377,28 @@ struct NDArray::Container { ...@@ -377,28 +377,28 @@ struct NDArray::Container {
* is only called when the reference counter goes to 0 * is only called when the reference counter goes to 0
*/ */
/*! /*!
* \brief Tensor structure. * @brief Tensor structure.
* \note it is important that the first field is DGLArray * @note it is important that the first field is DGLArray
* So that this data structure is DGLArray compatible. * So that this data structure is DGLArray compatible.
* The head ptr of this struct can be viewed as DGLArray*. * The head ptr of this struct can be viewed as DGLArray*.
*/ */
DGLArray dl_tensor; DGLArray dl_tensor;
/*! /*!
* \brief addtional context, reserved for recycling * @brief addtional context, reserved for recycling
* \note We can attach additional content here * @note We can attach additional content here
* which the current container depend on * which the current container depend on
* (e.g. reference to original memory when creating views). * (e.g. reference to original memory when creating views).
*/ */
void* manager_ctx{nullptr}; void* manager_ctx{nullptr};
/*! /*!
* \brief Customized deleter * @brief Customized deleter
* *
* \note The customized deleter is helpful to enable * @note The customized deleter is helpful to enable
* different ways of memory allocator that are not * different ways of memory allocator that are not
* currently defined by the system. * currently defined by the system.
*/ */
void (*deleter)(Container* self) = nullptr; void (*deleter)(Container* self) = nullptr;
/*! \brief default constructor */ /*! @brief default constructor */
Container() { Container() {
dl_tensor.data = nullptr; dl_tensor.data = nullptr;
dl_tensor.ndim = 0; dl_tensor.ndim = 0;
...@@ -406,13 +406,13 @@ struct NDArray::Container { ...@@ -406,13 +406,13 @@ struct NDArray::Container {
dl_tensor.strides = nullptr; dl_tensor.strides = nullptr;
dl_tensor.byte_offset = 0; dl_tensor.byte_offset = 0;
} }
/*! \brief pointer to shared memory */ /*! @brief pointer to shared memory */
std::shared_ptr<SharedMemory> mem; std::shared_ptr<SharedMemory> mem;
/*! \brief developer function, increases reference counter */ /*! @brief developer function, increases reference counter */
void IncRef() { void IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed); ref_counter_.fetch_add(1, std::memory_order_relaxed);
} }
/*! \brief developer function, decrease reference counter */ /*! @brief developer function, decrease reference counter */
void DecRef() { void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire); std::atomic_thread_fence(std::memory_order_acquire);
...@@ -427,16 +427,16 @@ struct NDArray::Container { ...@@ -427,16 +427,16 @@ struct NDArray::Container {
friend class NDArray; friend class NDArray;
friend class RPCWrappedFunc; friend class RPCWrappedFunc;
/*! /*!
* \brief The shape container, * @brief The shape container,
* can be used for shape data. * can be used for shape data.
*/ */
std::vector<int64_t> shape_; std::vector<int64_t> shape_;
/*! /*!
* \brief The stride container, * @brief The stride container,
* can be used for stride data. * can be used for stride data.
*/ */
std::vector<int64_t> stride_; std::vector<int64_t> stride_;
/*! \brief The internal array object */ /*! @brief The internal array object */
std::atomic<int> ref_counter_{0}; std::atomic<int> ref_counter_{0};
bool pinned_by_dgl_{false}; bool pinned_by_dgl_{false};
...@@ -527,7 +527,7 @@ inline const DGLArray* NDArray::operator->() const { ...@@ -527,7 +527,7 @@ inline const DGLArray* NDArray::operator->() const {
return &(data_->dl_tensor); return &(data_->dl_tensor);
} }
/*! \brief Magic number for NDArray file */ /*! @brief Magic number for NDArray file */
constexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F; constexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F;
inline bool SaveDGLArray(dmlc::Stream* strm, inline bool SaveDGLArray(dmlc::Stream* strm,
...@@ -580,9 +580,9 @@ inline bool SaveDGLArray(dmlc::Stream* strm, ...@@ -580,9 +580,9 @@ inline bool SaveDGLArray(dmlc::Stream* strm,
} }
/*! /*!
* \brief Convert type code to its name * @brief Convert type code to its name
* \param type_code The type code . * @param type_code The type code .
* \return The name of type code. * @return The name of type code.
*/ */
inline const char* TypeCode2Str(int type_code) { inline const char* TypeCode2Str(int type_code) {
switch (type_code) { switch (type_code) {
...@@ -606,9 +606,9 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -606,9 +606,9 @@ inline const char* TypeCode2Str(int type_code) {
} }
/*! /*!
* \brief Convert device type code to its name * @brief Convert device type code to its name
* \param device_type The device type code. * @param device_type The device type code.
* \return The name of the device. * @return The name of the device.
*/ */
inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) { inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
switch (device_type) { switch (device_type) {
...@@ -620,9 +620,9 @@ inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) { ...@@ -620,9 +620,9 @@ inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
} }
/*! /*!
* \brief convert a string to DGL type. * @brief convert a string to DGL type.
* \param s The string to be converted. * @param s The string to be converted.
* \return The corresponding dgl type. * @return The corresponding dgl type.
*/ */
inline DGLDataType String2DGLDataType(std::string s) { inline DGLDataType String2DGLDataType(std::string s) {
DGLDataType t; DGLDataType t;
...@@ -652,9 +652,9 @@ inline DGLDataType String2DGLDataType(std::string s) { ...@@ -652,9 +652,9 @@ inline DGLDataType String2DGLDataType(std::string s) {
} }
/*! /*!
* \brief convert a DGL type to string. * @brief convert a DGL type to string.
* \param t The type to be converted. * @param t The type to be converted.
* \return The corresponding dgl type in string. * @return The corresponding dgl type in string.
*/ */
inline std::string DGLDataType2String(DGLDataType t) { inline std::string DGLDataType2String(DGLDataType t) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS #ifndef _LIBCPP_SGX_NO_IOSTREAMS
...@@ -737,12 +737,12 @@ std::ostream& operator << (std::ostream& os, dgl::runtime::NDArray array); ...@@ -737,12 +737,12 @@ std::ostream& operator << (std::ostream& os, dgl::runtime::NDArray array);
///////////////// Operator overloading for DGLDataType ///////////////// ///////////////// Operator overloading for DGLDataType /////////////////
/*! \brief Check whether two data types are the same.*/ /*! @brief Check whether two data types are the same.*/
inline bool operator == (const DGLDataType& ty1, const DGLDataType& ty2) { inline bool operator == (const DGLDataType& ty1, const DGLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes; return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
} }
/*! \brief Check whether two data types are different.*/ /*! @brief Check whether two data types are different.*/
inline bool operator != (const DGLDataType& ty1, const DGLDataType& ty2) { inline bool operator != (const DGLDataType& ty1, const DGLDataType& ty2) {
return !(ty1 == ty2); return !(ty1 == ty2);
} }
...@@ -761,12 +761,12 @@ inline std::ostream& operator << (std::ostream& os, DGLDataType t) { ...@@ -761,12 +761,12 @@ inline std::ostream& operator << (std::ostream& os, DGLDataType t) {
///////////////// Operator overloading for DGLContext ///////////////// ///////////////// Operator overloading for DGLContext /////////////////
/*! \brief Check whether two device contexts are the same.*/ /*! @brief Check whether two device contexts are the same.*/
inline bool operator == (const DGLContext& ctx1, const DGLContext& ctx2) { inline bool operator == (const DGLContext& ctx1, const DGLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id; return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
} }
/*! \brief Check whether two device contexts are different.*/ /*! @brief Check whether two device contexts are different.*/
inline bool operator != (const DGLContext& ctx1, const DGLContext& ctx2) { inline bool operator != (const DGLContext& ctx1, const DGLContext& ctx2) {
return !(ctx1 == ctx2); return !(ctx1 == ctx2);
} }
......
...@@ -56,7 +56,7 @@ class Object { ...@@ -56,7 +56,7 @@ class Object {
public: public:
/*! @brief virtual destructor */ /*! @brief virtual destructor */
virtual ~Object() {} virtual ~Object() {}
/*! \return The unique type key of the object */ /*! @return The unique type key of the object */
virtual const char* type_key() const = 0; virtual const char* type_key() const = 0;
/*! /*!
* @brief Apply visitor to each field of the Object * @brief Apply visitor to each field of the Object
...@@ -65,7 +65,7 @@ class Object { ...@@ -65,7 +65,7 @@ class Object {
* @param visitor The visitor * @param visitor The visitor
*/ */
virtual void VisitAttrs(AttrVisitor* visitor) {} virtual void VisitAttrs(AttrVisitor* visitor) {}
/*! \return the type index of the object */ /*! @return the type index of the object */
virtual uint32_t type_index() const = 0; virtual uint32_t type_index() const = 0;
/*! /*!
* @brief Whether this object derives from object with type_index=tid. * @brief Whether this object derives from object with type_index=tid.
...@@ -116,7 +116,7 @@ class ObjectRef { ...@@ -116,7 +116,7 @@ class ObjectRef {
* *
* @param other Another object ref. * @param other Another object ref.
* @return the compare result. * @return the compare result.
* \sa same_as * @sa same_as
*/ */
inline bool operator==(const ObjectRef& other) const; inline bool operator==(const ObjectRef& other) const;
/*! /*!
...@@ -142,18 +142,18 @@ class ObjectRef { ...@@ -142,18 +142,18 @@ class ObjectRef {
* @brief Comparator * @brief Comparator
* @param other Another object ref. * @param other Another object ref.
* @return the compare result. * @return the compare result.
* \sa same_as * @sa same_as
*/ */
inline bool operator!=(const ObjectRef& other) const; inline bool operator!=(const ObjectRef& other) const;
/*! \return the hash function for ObjectRef */ /*! @return the hash function for ObjectRef */
inline size_t hash() const; inline size_t hash() const;
/*! \return whether the expression is null */ /*! @return whether the expression is null */
inline bool defined() const; inline bool defined() const;
/*! \return the internal type index of Object */ /*! @return the internal type index of Object */
inline uint32_t type_index() const; inline uint32_t type_index() const;
/*! \return the internal object pointer */ /*! @return the internal object pointer */
inline const Object* get() const; inline const Object* get() const;
/*! \return the internal object pointer */ /*! @return the internal object pointer */
inline const Object* operator->() const; inline const Object* operator->() const;
/*! /*!
* @brief Downcast this object to its actual type. * @brief Downcast this object to its actual type.
......
...@@ -56,7 +56,7 @@ class PackedFunc { ...@@ -56,7 +56,7 @@ class PackedFunc {
* @param args The arguments to the function. * @param args The arguments to the function.
* @param rv The return value. * @param rv The return value.
* *
* \code * @code
* // Example code on how to implemented FType * // Example code on how to implemented FType
* void MyPackedFunc(DGLArgs args, DGLRetValue* rv) { * void MyPackedFunc(DGLArgs args, DGLRetValue* rv) {
* // automatically convert arguments to desired type. * // automatically convert arguments to desired type.
...@@ -67,7 +67,7 @@ class PackedFunc { ...@@ -67,7 +67,7 @@ class PackedFunc {
* std::string my_return_value = "x"; * std::string my_return_value = "x";
* *rv = my_return_value; * *rv = my_return_value;
* } * }
* \endcode * @endcode
*/ */
using FType = std::function<void(DGLArgs args, DGLRetValue* rv)>; using FType = std::function<void(DGLArgs args, DGLRetValue* rv)>;
/*! @brief default constructor */ /*! @brief default constructor */
...@@ -82,14 +82,14 @@ class PackedFunc { ...@@ -82,14 +82,14 @@ class PackedFunc {
* @param args Arguments to be passed. * @param args Arguments to be passed.
* @tparam Args arguments to be passed. * @tparam Args arguments to be passed.
* *
* \code * @code
* // Example code on how to call packed function * // Example code on how to call packed function
* void CallPacked(PackedFunc f) { * void CallPacked(PackedFunc f) {
* // call like normal functions by pass in arguments * // call like normal functions by pass in arguments
* // return value is automatically converted back * // return value is automatically converted back
* int rvalue = f(1, 2.0); * int rvalue = f(1, 2.0);
* } * }
* \endcode * @endcode
*/ */
template <typename... Args> template <typename... Args>
inline DGLRetValue operator()(Args&&... args) const; inline DGLRetValue operator()(Args&&... args) const;
...@@ -99,11 +99,11 @@ class PackedFunc { ...@@ -99,11 +99,11 @@ class PackedFunc {
* @param rv The return value. * @param rv The return value.
*/ */
inline void CallPacked(DGLArgs args, DGLRetValue* rv) const; inline void CallPacked(DGLArgs args, DGLRetValue* rv) const;
/*! \return the internal body function */ /*! @return the internal body function */
inline FType body() const; inline FType body() const;
/*! \return Whether the packed function is nullptr */ /*! @return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const { return body_ == nullptr; } bool operator==(std::nullptr_t null) const { return body_ == nullptr; }
/*! \return Whether the packed function is not nullptr */ /*! @return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const { return body_ != nullptr; } bool operator!=(std::nullptr_t null) const { return body_ != nullptr; }
private: private:
...@@ -119,7 +119,7 @@ template <typename FType> ...@@ -119,7 +119,7 @@ template <typename FType>
class TypedPackedFunc; class TypedPackedFunc;
/*! /*!
* \anchor TypedPackedFuncAnchor * @anchor TypedPackedFuncAnchor
* @brief A PackedFunc wrapper to provide typed function signature. * @brief A PackedFunc wrapper to provide typed function signature.
* It is backed by a PackedFunc internally. * It is backed by a PackedFunc internally.
* *
...@@ -134,7 +134,7 @@ class TypedPackedFunc; ...@@ -134,7 +134,7 @@ class TypedPackedFunc;
* We can construct a TypedPackedFunc from a lambda function * We can construct a TypedPackedFunc from a lambda function
* with the same signature. * with the same signature.
* *
* \code * @code
* // user defined lambda function. * // user defined lambda function.
* auto addone = [](int x)->int { * auto addone = [](int x)->int {
* return x + 1; * return x + 1;
...@@ -146,7 +146,7 @@ class TypedPackedFunc; ...@@ -146,7 +146,7 @@ class TypedPackedFunc;
* int y = ftyped(1); * int y = ftyped(1);
* // Can be directly converted to PackedFunc * // Can be directly converted to PackedFunc
* PackedFunc packed = ftype; * PackedFunc packed = ftype;
* \endcode * @endcode
* @tparam R The return value of the function. * @tparam R The return value of the function.
* @tparam Args The argument signature of the function. * @tparam Args The argument signature of the function.
*/ */
...@@ -161,7 +161,7 @@ class TypedPackedFunc<R(Args...)> { ...@@ -161,7 +161,7 @@ class TypedPackedFunc<R(Args...)> {
* @brief construct by wrap a PackedFunc * @brief construct by wrap a PackedFunc
* *
* Example usage: * Example usage:
* \code * @code
* PackedFunc packed([](DGLArgs args, DGLRetValue *rv) { * PackedFunc packed([](DGLArgs args, DGLRetValue *rv) {
* int x = args[0]; * int x = args[0];
* *rv = x + 1; * *rv = x + 1;
...@@ -170,7 +170,7 @@ class TypedPackedFunc<R(Args...)> { ...@@ -170,7 +170,7 @@ class TypedPackedFunc<R(Args...)> {
* TypedPackedFunc<int(int)> ftyped(packed); * TypedPackedFunc<int(int)> ftyped(packed);
* // call the typed version. * // call the typed version.
* CHECK_EQ(ftyped(1), 2); * CHECK_EQ(ftyped(1), 2);
* \endcode * @endcode
* *
* @param packed The packed function * @param packed The packed function
*/ */
...@@ -179,13 +179,13 @@ class TypedPackedFunc<R(Args...)> { ...@@ -179,13 +179,13 @@ class TypedPackedFunc<R(Args...)> {
* @brief construct from a lambda function with the same signature. * @brief construct from a lambda function with the same signature.
* *
* Example usage: * Example usage:
* \code * @code
* auto typed_lambda = [](int x)->int { return x + 1; } * auto typed_lambda = [](int x)->int { return x + 1; }
* // construct from packed function * // construct from packed function
* TypedPackedFunc<int(int)> ftyped(typed_lambda); * TypedPackedFunc<int(int)> ftyped(typed_lambda);
* // call the typed version. * // call the typed version.
* CHECK_EQ(ftyped(1), 2); * CHECK_EQ(ftyped(1), 2);
* \endcode * @endcode
* *
* @param typed_lambda typed lambda function. * @param typed_lambda typed lambda function.
* @tparam FLambda the type of the lambda function. * @tparam FLambda the type of the lambda function.
...@@ -200,13 +200,13 @@ class TypedPackedFunc<R(Args...)> { ...@@ -200,13 +200,13 @@ class TypedPackedFunc<R(Args...)> {
* @brief copy assignment operator from typed lambda * @brief copy assignment operator from typed lambda
* *
* Example usage: * Example usage:
* \code * @code
* // construct from packed function * // construct from packed function
* TypedPackedFunc<int(int)> ftyped; * TypedPackedFunc<int(int)> ftyped;
* ftyped = [](int x) { return x + 1; } * ftyped = [](int x) { return x + 1; }
* // call the typed version. * // call the typed version.
* CHECK_EQ(ftyped(1), 2); * CHECK_EQ(ftyped(1), 2);
* \endcode * @endcode
* *
* @param typed_lambda typed lambda function. * @param typed_lambda typed lambda function.
* @tparam FLambda the type of the lambda function. * @tparam FLambda the type of the lambda function.
...@@ -274,7 +274,7 @@ class DGLArgs { ...@@ -274,7 +274,7 @@ class DGLArgs {
*/ */
DGLArgs(const DGLValue* values, const int* type_codes, int num_args) DGLArgs(const DGLValue* values, const int* type_codes, int num_args)
: values(values), type_codes(type_codes), num_args(num_args) {} : values(values), type_codes(type_codes), num_args(num_args) {}
/*! \return size of the arguments */ /*! @return size of the arguments */
inline int size() const; inline int size() const;
/*! /*!
* @brief Get i-th argument * @brief Get i-th argument
...@@ -668,7 +668,7 @@ class DGLRetValue : public DGLPODValue_ { ...@@ -668,7 +668,7 @@ class DGLRetValue : public DGLPODValue_ {
*ret_type_code = type_code_; *ret_type_code = type_code_;
type_code_ = kNull; type_code_ = kNull;
} }
/*! \return The value field, if the data is POD */ /*! @return The value field, if the data is POD */
const DGLValue& value() const { const DGLValue& value() const {
CHECK( CHECK(
type_code_ != kObjectHandle && type_code_ != kFuncHandle && type_code_ != kObjectHandle && type_code_ != kFuncHandle &&
......
/*! /*!
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* \file runtime/container.h * @file runtime/container.h
* \brief Defines the container object data structures. * @brief Defines the container object data structures.
*/ */
#ifndef DGL_RUNTIME_PARALLEL_FOR_H_ #ifndef DGL_RUNTIME_PARALLEL_FOR_H_
#define DGL_RUNTIME_PARALLEL_FOR_H_ #define DGL_RUNTIME_PARALLEL_FOR_H_
...@@ -57,7 +57,7 @@ inline size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) { ...@@ -57,7 +57,7 @@ inline size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) {
static DefaultGrainSizeT default_grain_size; static DefaultGrainSizeT default_grain_size;
/*! /*!
* \brief OpenMP-based parallel for loop. * @brief OpenMP-based parallel for loop.
* *
* It requires each thread's workload to have at least \a grain_size elements. * It requires each thread's workload to have at least \a grain_size elements.
* The loop body will be a function that takes in two arguments \a begin and \a end, which * The loop body will be a function that takes in two arguments \a begin and \a end, which
...@@ -102,7 +102,7 @@ void parallel_for( ...@@ -102,7 +102,7 @@ void parallel_for(
} }
/*! /*!
* \brief OpenMP-based parallel for loop with default grain size. * @brief OpenMP-based parallel for loop with default grain size.
* *
* parallel_for with grain size to default value, either 1 or controlled through * parallel_for with grain size to default value, either 1 or controlled through
* environment variable DGL_PARALLEL_FOR_GRAIN_SIZE. * environment variable DGL_PARALLEL_FOR_GRAIN_SIZE.
...@@ -118,7 +118,7 @@ void parallel_for( ...@@ -118,7 +118,7 @@ void parallel_for(
} }
/*! /*!
* \brief OpenMP-based two-stage parallel reduction. * @brief OpenMP-based two-stage parallel reduction.
* *
* The first-stage reduction function \a f works in parallel. Each thread's workload has * The first-stage reduction function \a f works in parallel. Each thread's workload has
* at least \a grain_size elements. The loop body will be a function that takes in * at least \a grain_size elements. The loop body will be a function that takes in
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
* then into the same global registry in C++. * then into the same global registry in C++.
* The goal is to mix the front-end language and the DGL back-end. * The goal is to mix the front-end language and the DGL back-end.
* *
* \code * @code
* // register the function as MyAPIFuncName * // register the function as MyAPIFuncName
* DGL_REGISTER_GLOBAL(MyAPIFuncName) * DGL_REGISTER_GLOBAL(MyAPIFuncName)
* .set_body([](DGLArgs args, DGLRetValue* rv) { * .set_body([](DGLArgs args, DGLRetValue* rv) {
* // my code. * // my code.
* }); * });
* \endcode * @endcode
*/ */
#ifndef DGL_RUNTIME_REGISTRY_H_ #ifndef DGL_RUNTIME_REGISTRY_H_
#define DGL_RUNTIME_REGISTRY_H_ #define DGL_RUNTIME_REGISTRY_H_
...@@ -51,12 +51,12 @@ class Registry { ...@@ -51,12 +51,12 @@ class Registry {
/*! /*!
* @brief set the body of the function to be TypedPackedFunc. * @brief set the body of the function to be TypedPackedFunc.
* *
* \code * @code
* *
* DGL_REGISTER_API("addone") * DGL_REGISTER_API("addone")
* .set_body_typed<int(int)>([](int x) { return x + 1; }); * .set_body_typed<int(int)>([](int x) { return x + 1; });
* *
* \endcode * @endcode
* *
* @param f The body of the function. * @param f The body of the function.
* @tparam FType the signature of the function. * @tparam FType the signature of the function.
...@@ -122,11 +122,11 @@ class Registry { ...@@ -122,11 +122,11 @@ class Registry {
/*! /*!
* @brief Register a function globally. * @brief Register a function globally.
* \code * @code
* DGL_REGISTER_GLOBAL("MyPrint") * DGL_REGISTER_GLOBAL("MyPrint")
* .set_body([](DGLArgs args, DGLRetValue* rv) { * .set_body([](DGLArgs args, DGLRetValue* rv) {
* }); * });
* \endcode * @endcode
*/ */
#define DGL_REGISTER_GLOBAL(OpName) \ #define DGL_REGISTER_GLOBAL(OpName) \
DGL_STR_CONCAT(DGL_FUNC_REG_VAR_DEF, __COUNTER__) = \ DGL_STR_CONCAT(DGL_FUNC_REG_VAR_DEF, __COUNTER__) = \
......
/*! /*!
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file dgl/runtime/serializer.h * @file dgl/runtime/serializer.h
* \brief Serializer extension to support DGL data types * @brief Serializer extension to support DGL data types
* Include this file to enable serialization of DGLDataType, DGLContext * Include this file to enable serialization of DGLDataType, DGLContext
*/ */
#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_ #ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment