Unverified Commit bcd37684 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Replace /*! with /**. (#4823)



* replace

* blabla

* balbla

* blabla
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 619d735d
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file featgraph/include/featgraph.h * @file featgraph/include/featgraph.h
* @brief FeatGraph kernel headers. * @brief FeatGraph kernel headers.
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file featgraph/src/featgraph.cc * @file featgraph/src/featgraph.cc
* @brief FeatGraph kernels. * @brief FeatGraph kernels.
......
/* /**
* NOTE(zihao): this file was modified from TVM project: * NOTE(zihao): this file was modified from TVM project:
* - * -
* https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc * https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
* under the License. * under the License.
*/ */
/*! /**
* @brief This is an all in one TVM runtime file. * @brief This is an all in one TVM runtime file.
* *
* You only have to use this file to compile libtvm_runtime to * You only have to use this file to compile libtvm_runtime to
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/array.h * @file dgl/array.h
* @brief Common array operations required by DGL. * @brief Common array operations required by DGL.
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/array_iterator.h * @file dgl/array_iterator.h
* @brief Various iterators. * @brief Various iterators.
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/array_ops.h * @file dgl/aten/array_ops.h
* @brief Common array operations required by DGL. * @brief Common array operations required by DGL.
...@@ -23,20 +23,20 @@ namespace aten { ...@@ -23,20 +23,20 @@ namespace aten {
// ID array // ID array
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
/*! @return A special array to represent null. */ /** @return A special array to represent null. */
inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1}, inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
const DGLContext& ctx = DGLContext{kDGLCPU, 0}) { const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
return NDArray::Empty({0}, dtype, ctx); return NDArray::Empty({0}, dtype, ctx);
} }
/*! /**
* @return Whether the input array is a null array. * @return Whether the input array is a null array.
*/ */
inline bool IsNullArray(NDArray array) { inline bool IsNullArray(NDArray array) {
return array->shape[0] == 0; return array->shape[0] == 0;
} }
/*! /**
* @brief Create a new id array with given length * @brief Create a new id array with given length
* @param length The array length * @param length The array length
* @param ctx The array context * @param ctx The array context
...@@ -47,7 +47,7 @@ IdArray NewIdArray(int64_t length, ...@@ -47,7 +47,7 @@ IdArray NewIdArray(int64_t length,
DGLContext ctx = DGLContext{kDGLCPU, 0}, DGLContext ctx = DGLContext{kDGLCPU, 0},
uint8_t nbits = 64); uint8_t nbits = 64);
/*! /**
* @brief Create a new id array using the given vector data * @brief Create a new id array using the given vector data
* @param vec The vector data * @param vec The vector data
* @param nbits The integer bits of the returned array * @param nbits The integer bits of the returned array
...@@ -59,7 +59,7 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -59,7 +59,7 @@ IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits = 64, uint8_t nbits = 64,
DGLContext ctx = DGLContext{kDGLCPU, 0}); DGLContext ctx = DGLContext{kDGLCPU, 0});
/*! /**
* @brief Return an array representing a 1D range. * @brief Return an array representing a 1D range.
* @param low Lower bound (inclusive). * @param low Lower bound (inclusive).
* @param high Higher bound (exclusive). * @param high Higher bound (exclusive).
...@@ -69,7 +69,7 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -69,7 +69,7 @@ IdArray VecToIdArray(const std::vector<T>& vec,
*/ */
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx); IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx);
/*! /**
* @brief Return an array full of the given value * @brief Return an array full of the given value
* @param val The value to fill. * @param val The value to fill.
* @param length Number of elements. * @param length Number of elements.
...@@ -79,7 +79,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx); ...@@ -79,7 +79,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx);
*/ */
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx); IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx);
/*! /**
* @brief Return an array full of the given value with the given type. * @brief Return an array full of the given value with the given type.
* @param val The value to fill. * @param val The value to fill.
* @param length Number of elements. * @param length Number of elements.
...@@ -89,13 +89,13 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx); ...@@ -89,13 +89,13 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx);
template <typename DType> template <typename DType>
NDArray Full(DType val, int64_t length, DGLContext ctx); NDArray Full(DType val, int64_t length, DGLContext ctx);
/*! @brief Create a deep copy of the given array */ /** @brief Create a deep copy of the given array */
IdArray Clone(IdArray arr); IdArray Clone(IdArray arr);
/*! @brief Convert the idarray to the given bit width */ /** @brief Convert the idarray to the given bit width */
IdArray AsNumBits(IdArray arr, uint8_t bits); IdArray AsNumBits(IdArray arr, uint8_t bits);
/*! @brief Arithmetic functions */ /** @brief Arithmetic functions */
IdArray Add(IdArray lhs, IdArray rhs); IdArray Add(IdArray lhs, IdArray rhs);
IdArray Sub(IdArray lhs, IdArray rhs); IdArray Sub(IdArray lhs, IdArray rhs);
IdArray Mul(IdArray lhs, IdArray rhs); IdArray Mul(IdArray lhs, IdArray rhs);
...@@ -138,30 +138,30 @@ IdArray LE(int64_t lhs, IdArray rhs); ...@@ -138,30 +138,30 @@ IdArray LE(int64_t lhs, IdArray rhs);
IdArray EQ(int64_t lhs, IdArray rhs); IdArray EQ(int64_t lhs, IdArray rhs);
IdArray NE(int64_t lhs, IdArray rhs); IdArray NE(int64_t lhs, IdArray rhs);
/*! @brief Stack two arrays (of len L) into a 2*L length array */ /** @brief Stack two arrays (of len L) into a 2*L length array */
IdArray HStack(IdArray arr1, IdArray arr2); IdArray HStack(IdArray arr1, IdArray arr2);
/*! @brief Return the indices of the elements that are non-zero. */ /** @brief Return the indices of the elements that are non-zero. */
IdArray NonZero(BoolArray bool_arr); IdArray NonZero(BoolArray bool_arr);
/*! /**
* @brief Return the data under the index. In numpy notation, A[I] * @brief Return the data under the index. In numpy notation, A[I]
* @tparam ValueType The type of return value. * @tparam ValueType The type of return value.
*/ */
template<typename ValueType> template<typename ValueType>
ValueType IndexSelect(NDArray array, int64_t index); ValueType IndexSelect(NDArray array, int64_t index);
/*! /**
* @brief Return the data under the index. In numpy notation, A[I] * @brief Return the data under the index. In numpy notation, A[I]
*/ */
NDArray IndexSelect(NDArray array, IdArray index); NDArray IndexSelect(NDArray array, IdArray index);
/*! /**
* @brief Return the data from `start` (inclusive) to `end` (exclusive). * @brief Return the data from `start` (inclusive) to `end` (exclusive).
*/ */
NDArray IndexSelect(NDArray array, int64_t start, int64_t end); NDArray IndexSelect(NDArray array, int64_t start, int64_t end);
/*! /**
* @brief Permute the elements of an array according to given indices. * @brief Permute the elements of an array according to given indices.
* *
* Only support 1D arrays. * Only support 1D arrays.
...@@ -175,7 +175,7 @@ NDArray IndexSelect(NDArray array, int64_t start, int64_t end); ...@@ -175,7 +175,7 @@ NDArray IndexSelect(NDArray array, int64_t start, int64_t end);
*/ */
NDArray Scatter(NDArray array, IdArray indices); NDArray Scatter(NDArray array, IdArray indices);
/*! /**
* @brief Scatter data into the output array. * @brief Scatter data into the output array.
* *
* Equivalent to: * Equivalent to:
...@@ -186,7 +186,7 @@ NDArray Scatter(NDArray array, IdArray indices); ...@@ -186,7 +186,7 @@ NDArray Scatter(NDArray array, IdArray indices);
*/ */
void Scatter_(IdArray index, NDArray value, NDArray out); void Scatter_(IdArray index, NDArray value, NDArray out);
/*! /**
* @brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats) * @brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats)
* @param array A 1D vector * @param array A 1D vector
* @param repeats A 1D integer vector for number of times to repeat for each element in * @param repeats A 1D integer vector for number of times to repeat for each element in
...@@ -194,7 +194,7 @@ void Scatter_(IdArray index, NDArray value, NDArray out); ...@@ -194,7 +194,7 @@ void Scatter_(IdArray index, NDArray value, NDArray out);
*/ */
NDArray Repeat(NDArray array, IdArray repeats); NDArray Repeat(NDArray array, IdArray repeats);
/*! /**
* @brief Relabel the given ids to consecutive ids. * @brief Relabel the given ids to consecutive ids.
* *
* Relabeling is done inplace. The mapping is created from the union * Relabeling is done inplace. The mapping is created from the union
...@@ -211,7 +211,7 @@ NDArray Repeat(NDArray array, IdArray repeats); ...@@ -211,7 +211,7 @@ NDArray Repeat(NDArray array, IdArray repeats);
*/ */
IdArray Relabel_(const std::vector<IdArray>& arrays); IdArray Relabel_(const std::vector<IdArray>& arrays);
/*! /**
* @brief concatenate the given id arrays to one array * @brief concatenate the given id arrays to one array
* *
* Example: * Example:
...@@ -224,12 +224,12 @@ IdArray Relabel_(const std::vector<IdArray>& arrays); ...@@ -224,12 +224,12 @@ IdArray Relabel_(const std::vector<IdArray>& arrays);
*/ */
NDArray Concat(const std::vector<IdArray>& arrays); NDArray Concat(const std::vector<IdArray>& arrays);
/*!\brief Return whether the array is a valid 1D int array*/ /** @brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) { inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDGLInt; return arr->ndim == 1 && arr->dtype.code == kDGLInt;
} }
/*! /**
* @brief Packs a tensor containing padded sequences of variable length. * @brief Packs a tensor containing padded sequences of variable length.
* *
* Similar to \c pack_padded_sequence in PyTorch, except that * Similar to \c pack_padded_sequence in PyTorch, except that
...@@ -261,7 +261,7 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) { ...@@ -261,7 +261,7 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
template<typename ValueType> template<typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value); std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
/*! /**
* @brief Batch-slice a 1D or 2D array, and then pack the list of sliced arrays * @brief Batch-slice a 1D or 2D array, and then pack the list of sliced arrays
* by concatenation. * by concatenation.
* *
...@@ -291,7 +291,7 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value); ...@@ -291,7 +291,7 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
*/ */
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths); std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
/*! /**
* @brief Return the cumulative summation (or inclusive sum) of the input array. * @brief Return the cumulative summation (or inclusive sum) of the input array.
* *
* The first element out[0] is equal to the first element of the input array * The first element out[0] is equal to the first element of the input array
...@@ -307,7 +307,7 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths); ...@@ -307,7 +307,7 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
*/ */
IdArray CumSum(IdArray array, bool prepend_zero = false); IdArray CumSum(IdArray array, bool prepend_zero = false);
/*! /**
* @brief Return the nonzero index. * @brief Return the nonzero index.
* *
* Only support 1D array. The result index array is in int64. * Only support 1D array. The result index array is in int64.
...@@ -317,7 +317,7 @@ IdArray CumSum(IdArray array, bool prepend_zero = false); ...@@ -317,7 +317,7 @@ IdArray CumSum(IdArray array, bool prepend_zero = false);
*/ */
IdArray NonZero(NDArray array); IdArray NonZero(NDArray array);
/*! /**
* @brief Sort the ID vector in ascending order. * @brief Sort the ID vector in ascending order.
* *
* It performs both sort and arg_sort (returning the sorted index). The sorted index * It performs both sort and arg_sort (returning the sorted index). The sorted index
...@@ -334,7 +334,7 @@ IdArray NonZero(NDArray array); ...@@ -334,7 +334,7 @@ IdArray NonZero(NDArray array);
*/ */
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0); std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0);
/*! /**
* @brief Return a string that prints out some debug information. * @brief Return a string that prints out some debug information.
*/ */
std::string ToDebugString(NDArray array); std::string ToDebugString(NDArray array);
...@@ -355,7 +355,7 @@ IdArray VecToIdArray(const std::vector<T>& vec, ...@@ -355,7 +355,7 @@ IdArray VecToIdArray(const std::vector<T>& vec,
return ret.CopyTo(ctx); return ret.CopyTo(ctx);
} }
/*! /**
* @brief Get the context of the first array, and check if the non-null arrays' * @brief Get the context of the first array, and check if the non-null arrays'
* contexts are the same. * contexts are the same.
*/ */
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/coo.h * @file dgl/aten/coo.h
* @brief Common COO operations required by DGL. * @brief Common COO operations required by DGL.
...@@ -23,7 +23,7 @@ namespace aten { ...@@ -23,7 +23,7 @@ namespace aten {
struct CSRMatrix; struct CSRMatrix;
/*! /**
* @brief Plain COO structure * @brief Plain COO structure
* *
* The data array stores integer ids for reading edge features. * The data array stores integer ids for reading edge features.
...@@ -37,21 +37,21 @@ constexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127; ...@@ -37,21 +37,21 @@ constexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127;
// TODO(BarclayII): Graph queries on COO formats should support the case where // TODO(BarclayII): Graph queries on COO formats should support the case where
// data ordered by rows/columns instead of EID. // data ordered by rows/columns instead of EID.
struct COOMatrix { struct COOMatrix {
/*! @brief the dense shape of the matrix */ /** @brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0; int64_t num_rows = 0, num_cols = 0;
/*! @brief COO index arrays */ /** @brief COO index arrays */
IdArray row, col; IdArray row, col;
/*! @brief data index array. When is null, assume it is from 0 to NNZ - 1. */ /** @brief data index array. When is null, assume it is from 0 to NNZ - 1. */
IdArray data; IdArray data;
/*! @brief whether the row indices are sorted */ /** @brief whether the row indices are sorted */
bool row_sorted = false; bool row_sorted = false;
/*! @brief whether the column indices per row are sorted */ /** @brief whether the column indices per row are sorted */
bool col_sorted = false; bool col_sorted = false;
/*! @brief whether the matrix is in pinned memory */ /** @brief whether the matrix is in pinned memory */
bool is_pinned = false; bool is_pinned = false;
/*! @brief default constructor */ /** @brief default constructor */
COOMatrix() = default; COOMatrix() = default;
/*! @brief constructor */ /** @brief constructor */
COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr, COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false, IdArray darr = NullArray(), bool rsorted = false,
bool csorted = false) bool csorted = false)
...@@ -65,7 +65,7 @@ struct COOMatrix { ...@@ -65,7 +65,7 @@ struct COOMatrix {
CheckValidity(); CheckValidity();
} }
/*! @brief constructor from SparseMatrix object */ /** @brief constructor from SparseMatrix object */
explicit COOMatrix(const SparseMatrix& spmat) explicit COOMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows), : num_rows(spmat.num_rows),
num_cols(spmat.num_cols), num_cols(spmat.num_cols),
...@@ -121,7 +121,7 @@ struct COOMatrix { ...@@ -121,7 +121,7 @@ struct COOMatrix {
CHECK_NO_OVERFLOW(row->dtype, num_cols); CHECK_NO_OVERFLOW(row->dtype, num_cols);
} }
/*! @brief Return a copy of this matrix on the give device context. */ /** @brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DGLContext &ctx) const { inline COOMatrix CopyTo(const DGLContext &ctx) const {
if (ctx == row->ctx) if (ctx == row->ctx)
return *this; return *this;
...@@ -130,7 +130,7 @@ struct COOMatrix { ...@@ -130,7 +130,7 @@ struct COOMatrix {
row_sorted, col_sorted); row_sorted, col_sorted);
} }
/*! /**
* @brief Pin the row, col and data (if not Null) of the matrix. * @brief Pin the row, col and data (if not Null) of the matrix.
* @note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* kDGLCPU: will be pinned; * kDGLCPU: will be pinned;
...@@ -149,7 +149,7 @@ struct COOMatrix { ...@@ -149,7 +149,7 @@ struct COOMatrix {
is_pinned = true; is_pinned = true;
} }
/*! /**
* @brief Unpin the row, col and data (if not Null) of the matrix. * @brief Unpin the row, col and data (if not Null) of the matrix.
* @note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned; * IsPinned: will be unpinned;
...@@ -167,7 +167,7 @@ struct COOMatrix { ...@@ -167,7 +167,7 @@ struct COOMatrix {
is_pinned = false; is_pinned = false;
} }
/*! /**
* @brief Record stream for the row, col and data (if not Null) of the matrix. * @brief Record stream for the row, col and data (if not Null) of the matrix.
* @param stream The stream that is using the graph * @param stream The stream that is using the graph
*/ */
...@@ -182,28 +182,28 @@ struct COOMatrix { ...@@ -182,28 +182,28 @@ struct COOMatrix {
///////////////////////// COO routines ////////////////////////// ///////////////////////// COO routines //////////////////////////
/*! @brief Return true if the value (row, col) is non-zero */ /** @brief Return true if the value (row, col) is non-zero */
bool COOIsNonZero(COOMatrix , int64_t row, int64_t col); bool COOIsNonZero(COOMatrix , int64_t row, int64_t col);
/*! /**
* @brief Batched implementation of COOIsNonZero. * @brief Batched implementation of COOIsNonZero.
* @note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/ */
runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col); runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col);
/*! @brief Return the nnz of the given row */ /** @brief Return the nnz of the given row */
int64_t COOGetRowNNZ(COOMatrix , int64_t row); int64_t COOGetRowNNZ(COOMatrix , int64_t row);
runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row); runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row);
/*! @brief Return the data array of the given row */ /** @brief Return the data array of the given row */
std::pair<runtime::NDArray, runtime::NDArray> std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix , int64_t row); COOGetRowDataAndIndices(COOMatrix , int64_t row);
/*! @brief Whether the COO matrix contains data */ /** @brief Whether the COO matrix contains data */
inline bool COOHasData(COOMatrix csr) { inline bool COOHasData(COOMatrix csr) {
return !IsNullArray(csr.data); return !IsNullArray(csr.data);
} }
/*! /**
* @brief Check whether the COO is sorted. * @brief Check whether the COO is sorted.
* *
* It returns two flags: one for whether the row is sorted; * It returns two flags: one for whether the row is sorted;
...@@ -214,7 +214,7 @@ inline bool COOHasData(COOMatrix csr) { ...@@ -214,7 +214,7 @@ inline bool COOHasData(COOMatrix csr) {
*/ */
std::pair<bool, bool> COOIsSorted(COOMatrix coo); std::pair<bool, bool> COOIsSorted(COOMatrix coo);
/*! /**
* @brief Get the data and the row,col indices for each returned entries. * @brief Get the data and the row,col indices for each returned entries.
* *
* The operator supports matrix with duplicate entries and all the matched entries * The operator supports matrix with duplicate entries and all the matched entries
...@@ -230,7 +230,7 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo); ...@@ -230,7 +230,7 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo);
std::vector<runtime::NDArray> COOGetDataAndIndices( std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols); COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/*! @brief Get data. The return type is an ndarray due to possible duplicate entries. */ /** @brief Get data. The return type is an ndarray due to possible duplicate entries. */
inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) { inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
IdArray rows = VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx); IdArray rows = VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);
IdArray cols = VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx); IdArray cols = VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);
...@@ -238,7 +238,7 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) { ...@@ -238,7 +238,7 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
return rst[2]; return rst[2];
} }
/*! /**
* @brief Get the data for each (row, col) pair. * @brief Get the data for each (row, col) pair.
* *
* The operator supports matrix with duplicate entries but only one matched entry * The operator supports matrix with duplicate entries but only one matched entry
...@@ -254,10 +254,10 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) { ...@@ -254,10 +254,10 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
*/ */
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols); runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/*! @brief Return a transposed COO matrix */ /** @brief Return a transposed COO matrix */
COOMatrix COOTranspose(COOMatrix coo); COOMatrix COOTranspose(COOMatrix coo);
/*! /**
* @brief Convert COO matrix to CSR matrix. * @brief Convert COO matrix to CSR matrix.
* *
* If the input COO matrix does not have data array, the data array of * If the input COO matrix does not have data array, the data array of
...@@ -281,7 +281,7 @@ COOMatrix COOTranspose(COOMatrix coo); ...@@ -281,7 +281,7 @@ COOMatrix COOTranspose(COOMatrix coo);
*/ */
CSRMatrix COOToCSR(COOMatrix coo); CSRMatrix COOToCSR(COOMatrix coo);
/*! /**
* @brief Slice rows of the given matrix and return. * @brief Slice rows of the given matrix and return.
* @param coo COO matrix * @param coo COO matrix
* @param start Start row id (inclusive) * @param start Start row id (inclusive)
...@@ -290,7 +290,7 @@ CSRMatrix COOToCSR(COOMatrix coo); ...@@ -290,7 +290,7 @@ CSRMatrix COOToCSR(COOMatrix coo);
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end); COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows); COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
/*! /**
* @brief Get the submatrix specified by the row and col ids. * @brief Get the submatrix specified by the row and col ids.
* *
* In numpy notation, given matrix M, row index array I, col index array J * In numpy notation, given matrix M, row index array I, col index array J
...@@ -303,16 +303,16 @@ COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows); ...@@ -303,16 +303,16 @@ COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
*/ */
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
/*! @return True if the matrix has duplicate entries */ /** @return True if the matrix has duplicate entries */
bool COOHasDuplicate(COOMatrix coo); bool COOHasDuplicate(COOMatrix coo);
/*! /**
* @brief Deduplicate the entries of a sorted COO matrix, replacing the data with the * @brief Deduplicate the entries of a sorted COO matrix, replacing the data with the
* number of occurrences of the row-col coordinates. * number of occurrences of the row-col coordinates.
*/ */
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo); std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
/*! /**
* @brief Sort the indices of a COO matrix in-place. * @brief Sort the indices of a COO matrix in-place.
* *
* The function sorts row indices in ascending order. If sort_column is true, * The function sorts row indices in ascending order. If sort_column is true,
...@@ -327,7 +327,7 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo); ...@@ -327,7 +327,7 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
*/ */
void COOSort_(COOMatrix* mat, bool sort_column = false); void COOSort_(COOMatrix* mat, bool sort_column = false);
/*! /**
* @brief Sort the indices of a COO matrix. * @brief Sort the indices of a COO matrix.
* *
* The function sorts row indices in ascending order. If sort_column is true, * The function sorts row indices in ascending order. If sort_column is true,
...@@ -352,14 +352,14 @@ inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) { ...@@ -352,14 +352,14 @@ inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) {
return ret; return ret;
} }
/*! /**
* @brief Remove entries from COO matrix by entry indices (data indices) * @brief Remove entries from COO matrix by entry indices (data indices)
* @return A new COO matrix as well as a mapping from the new COO entries to the old COO * @return A new COO matrix as well as a mapping from the new COO entries to the old COO
* entries. * entries.
*/ */
COOMatrix COORemove(COOMatrix coo, IdArray entries); COOMatrix COORemove(COOMatrix coo, IdArray entries);
/*! /**
* @brief Reorder the rows and colmns according to the new row and column order. * @brief Reorder the rows and colmns according to the new row and column order.
* @param csr The input coo matrix. * @param csr The input coo matrix.
* @param new_row_ids the new row Ids (the index is the old row Id) * @param new_row_ids the new row Ids (the index is the old row Id)
...@@ -367,7 +367,7 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries); ...@@ -367,7 +367,7 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries);
*/ */
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/*! /**
* @brief Randomly select a fixed number of non-zero entries along each given row independently. * @brief Randomly select a fixed number of non-zero entries along each given row independently.
* *
* The function performs random choices along each row independently. * The function performs random choices along each row independently.
...@@ -410,7 +410,7 @@ COOMatrix COORowWiseSampling( ...@@ -410,7 +410,7 @@ COOMatrix COORowWiseSampling(
NDArray prob_or_mask = NDArray(), NDArray prob_or_mask = NDArray(),
bool replace = true); bool replace = true);
/*! /**
* @brief Randomly select a fixed number of non-zero entries for each edge type * @brief Randomly select a fixed number of non-zero entries for each edge type
* along each given row independently. * along each given row independently.
* *
...@@ -462,7 +462,7 @@ COOMatrix COORowWisePerEtypeSampling( ...@@ -462,7 +462,7 @@ COOMatrix COORowWisePerEtypeSampling(
const std::vector<NDArray>& prob_or_mask, const std::vector<NDArray>& prob_or_mask,
bool replace = true); bool replace = true);
/*! /**
* @brief Select K non-zero entries with the largest weights along each given row. * @brief Select K non-zero entries with the largest weights along each given row.
* *
* The function performs top-k selection along each row independently. * The function performs top-k selection along each row independently.
...@@ -506,7 +506,7 @@ COOMatrix COORowWiseTopk( ...@@ -506,7 +506,7 @@ COOMatrix COORowWiseTopk(
NDArray weight, NDArray weight,
bool ascending = false); bool ascending = false);
/*! /**
* @brief Union two COOMatrix into one COOMatrix. * @brief Union two COOMatrix into one COOMatrix.
* *
* Two Matrix must have the same shape. * Two Matrix must have the same shape.
...@@ -538,7 +538,7 @@ COOMatrix COORowWiseTopk( ...@@ -538,7 +538,7 @@ COOMatrix COORowWiseTopk(
COOMatrix UnionCoo( COOMatrix UnionCoo(
const std::vector<COOMatrix>& coos); const std::vector<COOMatrix>& coos);
/*! /**
* @brief DisjointUnion a list COOMatrix into one COOMatrix. * @brief DisjointUnion a list COOMatrix into one COOMatrix.
* *
* Examples: * Examples:
...@@ -573,7 +573,7 @@ COOMatrix UnionCoo( ...@@ -573,7 +573,7 @@ COOMatrix UnionCoo(
COOMatrix DisjointUnionCoo( COOMatrix DisjointUnionCoo(
const std::vector<COOMatrix>& coos); const std::vector<COOMatrix>& coos);
/*! /**
* @brief COOMatrix toSimple. * @brief COOMatrix toSimple.
* *
* A = [[0, 0, 0], * A = [[0, 0, 0],
...@@ -597,7 +597,7 @@ COOMatrix DisjointUnionCoo( ...@@ -597,7 +597,7 @@ COOMatrix DisjointUnionCoo(
*/ */
std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo); std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
/*! /**
* @brief Split a COOMatrix into multiple disjoin components. * @brief Split a COOMatrix into multiple disjoin components.
* *
* Examples: * Examples:
...@@ -648,7 +648,7 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes( ...@@ -648,7 +648,7 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
const std::vector<uint64_t> &src_vertex_cumsum, const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &dst_vertex_cumsum); const std::vector<uint64_t> &dst_vertex_cumsum);
/*! /**
* @brief Slice a contiguous chunk from a COOMatrix * @brief Slice a contiguous chunk from a COOMatrix
* *
* Examples: * Examples:
...@@ -689,7 +689,7 @@ COOMatrix COOSliceContiguousChunk( ...@@ -689,7 +689,7 @@ COOMatrix COOSliceContiguousChunk(
const std::vector<uint64_t> &src_vertex_range, const std::vector<uint64_t> &src_vertex_range,
const std::vector<uint64_t> &dst_vertex_range); const std::vector<uint64_t> &dst_vertex_range);
/*! /**
* @brief Create a LineGraph of input coo * @brief Create a LineGraph of input coo
* *
* A = [[0, 0, 1], * A = [[0, 0, 1],
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/csr.h * @file dgl/aten/csr.h
* @brief Common CSR operations required by DGL. * @brief Common CSR operations required by DGL.
...@@ -8,25 +8,27 @@ ...@@ -8,25 +8,27 @@
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/serializer.h> #include <dmlc/serializer.h>
#include <vector>
#include <tuple>
#include <string> #include <string>
#include <tuple>
#include <utility> #include <utility>
#include "./types.h" #include <vector>
#include "./array_ops.h" #include "./array_ops.h"
#include "./spmat.h"
#include "./macro.h" #include "./macro.h"
#include "./spmat.h"
#include "./types.h"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
struct COOMatrix; struct COOMatrix;
/*! /**
* @brief Plain CSR matrix * @brief Plain CSR matrix
* *
* The column indices are 0-based and are not necessarily sorted. The data array stores * The column indices are 0-based and are not necessarily sorted. The data array
* integer ids for reading edge features. * stores integer ids for reading edge features.
* *
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries * Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in * that have the same row, col indices. It corresponds to multigraph in
...@@ -36,21 +38,22 @@ struct COOMatrix; ...@@ -36,21 +38,22 @@ struct COOMatrix;
constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127; constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;
struct CSRMatrix { struct CSRMatrix {
/*! @brief the dense shape of the matrix */ /** @brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0; int64_t num_rows = 0, num_cols = 0;
/*! @brief CSR index arrays */ /** @brief CSR index arrays */
IdArray indptr, indices; IdArray indptr, indices;
/*! @brief data index array. When is null, assume it is from 0 to NNZ - 1. */ /** @brief data index array. When is null, assume it is from 0 to NNZ - 1. */
IdArray data; IdArray data;
/*! @brief whether the column indices per row are sorted */ /** @brief whether the column indices per row are sorted */
bool sorted = false; bool sorted = false;
/*! @brief whether the matrix is in pinned memory */ /** @brief whether the matrix is in pinned memory */
bool is_pinned = false; bool is_pinned = false;
/*! @brief default constructor */ /** @brief default constructor */
CSRMatrix() = default; CSRMatrix() = default;
/*! @brief constructor */ /** @brief constructor */
CSRMatrix(int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr, CSRMatrix(
IdArray darr = NullArray(), bool sorted_flag = false) int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr,
IdArray darr = NullArray(), bool sorted_flag = false)
: num_rows(nrows), : num_rows(nrows),
num_cols(ncols), num_cols(ncols),
indptr(parr), indptr(parr),
...@@ -60,7 +63,7 @@ struct CSRMatrix { ...@@ -60,7 +63,7 @@ struct CSRMatrix {
CheckValidity(); CheckValidity();
} }
/*! @brief constructor from SparseMatrix object */ /** @brief constructor from SparseMatrix object */
explicit CSRMatrix(const SparseMatrix& spmat) explicit CSRMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows), : num_rows(spmat.num_rows),
num_cols(spmat.num_cols), num_cols(spmat.num_cols),
...@@ -73,8 +76,9 @@ struct CSRMatrix { ...@@ -73,8 +76,9 @@ struct CSRMatrix {
// Convert to a SparseMatrix object that can return to python. // Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const { SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCSR), num_rows, return SparseMatrix(
num_cols, {indptr, indices, data}, {sorted}); static_cast<int32_t>(SparseFormat::kCSR), num_rows, num_cols,
{indptr, indices, data}, {sorted});
} }
bool Load(dmlc::Stream* fs) { bool Load(dmlc::Stream* fs) {
...@@ -114,25 +118,24 @@ struct CSRMatrix { ...@@ -114,25 +118,24 @@ struct CSRMatrix {
CHECK_EQ(indptr->shape[0], num_rows + 1); CHECK_EQ(indptr->shape[0], num_rows + 1);
} }
/*! @brief Return a copy of this matrix on the give device context. */ /** @brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DGLContext &ctx) const { inline CSRMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == indptr->ctx) if (ctx == indptr->ctx) return *this;
return *this; return CSRMatrix(
return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx), indices.CopyTo(ctx), num_rows, num_cols, indptr.CopyTo(ctx), indices.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx), sorted); aten::IsNullArray(data) ? data : data.CopyTo(ctx), sorted);
} }
/*! /**
* @brief Pin the indptr, indices and data (if not Null) of the matrix. * @brief Pin the indptr, indices and data (if not Null) of the matrix.
* @note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* kDGLCPU: will be pinned; * kDGLCPU: will be pinned;
* IsPinned: directly return; * IsPinned: directly return;
* kDGLCUDA: invalid, will throw an error. * kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
if (is_pinned) if (is_pinned) return;
return;
indptr.PinMemory_(); indptr.PinMemory_();
indices.PinMemory_(); indices.PinMemory_();
if (!aten::IsNullArray(data)) { if (!aten::IsNullArray(data)) {
...@@ -141,16 +144,15 @@ struct CSRMatrix { ...@@ -141,16 +144,15 @@ struct CSRMatrix {
is_pinned = true; is_pinned = true;
} }
/*! /**
* @brief Unpin the indptr, indices and data (if not Null) of the matrix. * @brief Unpin the indptr, indices and data (if not Null) of the matrix.
* @note This is an in-place method. Behavior depends on the current context, * @note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned; * IsPinned: will be unpinned;
* others: directly return. * others: directly return.
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
inline void UnpinMemory_() { inline void UnpinMemory_() {
if (!is_pinned) if (!is_pinned) return;
return;
indptr.UnpinMemory_(); indptr.UnpinMemory_();
indices.UnpinMemory_(); indices.UnpinMemory_();
if (!aten::IsNullArray(data)) { if (!aten::IsNullArray(data)) {
...@@ -159,8 +161,9 @@ struct CSRMatrix { ...@@ -159,8 +161,9 @@ struct CSRMatrix {
is_pinned = false; is_pinned = false;
} }
/*! /**
* @brief Record stream for the indptr, indices and data (if not Null) of the matrix. * @brief Record stream for the indptr, indices and data (if not Null) of the
* matrix.
* @param stream The stream that is using the graph * @param stream The stream that is using the graph
*/ */
inline void RecordStream(DGLStreamHandle stream) const { inline void RecordStream(DGLStreamHandle stream) const {
...@@ -174,52 +177,54 @@ struct CSRMatrix { ...@@ -174,52 +177,54 @@ struct CSRMatrix {
///////////////////////// CSR routines ////////////////////////// ///////////////////////// CSR routines //////////////////////////
/*! @brief Return true if the value (row, col) is non-zero */ /** @brief Return true if the value (row, col) is non-zero */
bool CSRIsNonZero(CSRMatrix , int64_t row, int64_t col); bool CSRIsNonZero(CSRMatrix, int64_t row, int64_t col);
/*! /**
* @brief Batched implementation of CSRIsNonZero. * @brief Batched implementation of CSRIsNonZero.
* @note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
*/ */
runtime::NDArray CSRIsNonZero(CSRMatrix, runtime::NDArray row, runtime::NDArray col); runtime::NDArray CSRIsNonZero(
CSRMatrix, runtime::NDArray row, runtime::NDArray col);
/*! @brief Return the nnz of the given row */ /** @brief Return the nnz of the given row */
int64_t CSRGetRowNNZ(CSRMatrix , int64_t row); int64_t CSRGetRowNNZ(CSRMatrix, int64_t row);
runtime::NDArray CSRGetRowNNZ(CSRMatrix , runtime::NDArray row); runtime::NDArray CSRGetRowNNZ(CSRMatrix, runtime::NDArray row);
/*! @brief Return the column index array of the given row */ /** @brief Return the column index array of the given row */
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix , int64_t row); runtime::NDArray CSRGetRowColumnIndices(CSRMatrix, int64_t row);
/*! @brief Return the data array of the given row */ /** @brief Return the data array of the given row */
runtime::NDArray CSRGetRowData(CSRMatrix , int64_t row); runtime::NDArray CSRGetRowData(CSRMatrix, int64_t row);
/*! @brief Whether the CSR matrix contains data */ /** @brief Whether the CSR matrix contains data */
inline bool CSRHasData(CSRMatrix csr) { inline bool CSRHasData(CSRMatrix csr) { return !IsNullArray(csr.data); }
return !IsNullArray(csr.data);
}
/*! @brief Whether the column indices of each row is sorted. */ /** @brief Whether the column indices of each row is sorted. */
bool CSRIsSorted(CSRMatrix csr); bool CSRIsSorted(CSRMatrix csr);
/*! /**
* @brief Get the data and the row,col indices for each returned entries. * @brief Get the data and the row,col indices for each returned entries.
* *
* The operator supports matrix with duplicate entries and all the matched entries * The operator supports matrix with duplicate entries and all the matched
* will be returned. The operator assumes there is NO duplicate (row, col) pair * entries will be returned. The operator assumes there is NO duplicate (row,
* in the given input. Otherwise, the returned result is undefined. * col) pair in the given input. Otherwise, the returned result is undefined.
* *
* If some (row, col) pairs do not contain a valid non-zero elements, * If some (row, col) pairs do not contain a valid non-zero elements,
* they will not be included in the return arrays. * they will not be included in the return arrays.
* *
* @note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
* @param mat Sparse matrix * @param mat Sparse matrix
* @param rows Row index * @param rows Row index
* @param cols Column index * @param cols Column index
* @return Three arrays {rows, cols, data} * @return Three arrays {rows, cols, data}
*/ */
std::vector<runtime::NDArray> CSRGetDataAndIndices( std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix , runtime::NDArray rows, runtime::NDArray cols); CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
/* @brief Get data. The return type is an ndarray due to possible duplicate entries. */ /* @brief Get data. The return type is an ndarray due to possible duplicate
* entries. */
inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) { inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
const auto& nbits = mat.indptr->dtype.bits; const auto& nbits = mat.indptr->dtype.bits;
const auto& ctx = mat.indptr->ctx; const auto& ctx = mat.indptr->ctx;
...@@ -229,54 +234,60 @@ inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) { ...@@ -229,54 +234,60 @@ inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
return rst[2]; return rst[2];
} }
/*! /**
* @brief Get the data for each (row, col) pair. * @brief Get the data for each (row, col) pair.
* *
* The operator supports matrix with duplicate entries but only one matched entry * The operator supports matrix with duplicate entries but only one matched
* will be returned for each (row, col) pair. Support duplicate input (row, col) * entry will be returned for each (row, col) pair. Support duplicate input
* pairs. * (row, col) pairs.
* *
* If some (row, col) pairs do not contain a valid non-zero elements, * If some (row, col) pairs do not contain a valid non-zero elements,
* their data values are filled with -1. * their data values are filled with -1.
* *
* @note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
* *
* @param mat Sparse matrix. * @param mat Sparse matrix.
* @param rows Row index. * @param rows Row index.
* @param cols Column index. * @param cols Column index.
* @return Data array. The i^th element is the data of (rows[i], cols[i]) * @return Data array. The i^th element is the data of (rows[i], cols[i])
*/ */
runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray cols); runtime::NDArray CSRGetData(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
/*! /**
* @brief Get the data for each (row, col) pair, then index into the weights array. * @brief Get the data for each (row, col) pair, then index into the weights
* array.
* *
* The operator supports matrix with duplicate entries but only one matched entry * The operator supports matrix with duplicate entries but only one matched
* will be returned for each (row, col) pair. Support duplicate input (row, col) * entry will be returned for each (row, col) pair. Support duplicate input
* pairs. * (row, col) pairs.
* *
* If some (row, col) pairs do not contain a valid non-zero elements to index into the * If some (row, col) pairs do not contain a valid non-zero elements to index
* weights array, DGL returns the value \a filler for that pair instead. * into the weights array, DGL returns the value \a filler for that pair
* instead.
* *
* @note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
* *
* @tparam DType the data type of the weights array. * @tparam DType the data type of the weights array.
* @param mat Sparse matrix. * @param mat Sparse matrix.
* @param rows Row index. * @param rows Row index.
* @param cols Column index. * @param cols Column index.
* @param weights The weights array. * @param weights The weights array.
* @param filler The value to return for row-column pairs not existent in the matrix. * @param filler The value to return for row-column pairs not existent in the
* matrix.
* @return Data array. The i^th element is the data of (rows[i], cols[i]) * @return Data array. The i^th element is the data of (rows[i], cols[i])
*/ */
template <typename DType> template <typename DType>
runtime::NDArray CSRGetData( runtime::NDArray CSRGetData(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols, runtime::NDArray weights, CSRMatrix, runtime::NDArray rows, runtime::NDArray cols,
DType filler); runtime::NDArray weights, DType filler);
/*! @brief Return a transposed CSR matrix */ /** @brief Return a transposed CSR matrix */
CSRMatrix CSRTranspose(CSRMatrix csr); CSRMatrix CSRTranspose(CSRMatrix csr);
/*! /**
* @brief Convert CSR matrix to COO matrix. * @brief Convert CSR matrix to COO matrix.
* *
* Complexity: O(nnz) * Complexity: O(nnz)
...@@ -288,15 +299,15 @@ CSRMatrix CSRTranspose(CSRMatrix csr); ...@@ -288,15 +299,15 @@ CSRMatrix CSRTranspose(CSRMatrix csr);
* column sorted. * column sorted.
* *
* @param csr Input csr matrix * @param csr Input csr matrix
* @param data_as_order If true, the data array in the input csr matrix contains the order * @param data_as_order If true, the data array in the input csr matrix contains
* by which the resulting COO tuples are stored. In this case, the * the order by which the resulting COO tuples are stored. In this case, the
* data array of the resulting COO matrix will be empty because it * data array of the resulting COO matrix will be empty
* is essentially a consecutive range. * because it is essentially a consecutive range.
* @return a coo matrix * @return a coo matrix
*/ */
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order); COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order);
/*! /**
* @brief Slice rows of the given matrix and return. * @brief Slice rows of the given matrix and return.
* *
* The sliced row IDs are relabeled to starting from zero. * The sliced row IDs are relabeled to starting from zero.
...@@ -322,7 +333,7 @@ COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order); ...@@ -322,7 +333,7 @@ COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order);
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end); CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows); CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
/*! /**
* @brief Get the submatrix specified by the row and col ids. * @brief Get the submatrix specified by the row and col ids.
* *
* In numpy notation, given matrix M, row index array I, col index array J * In numpy notation, given matrix M, row index array I, col index array J
...@@ -339,16 +350,17 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows); ...@@ -339,16 +350,17 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
* @param cols The col index to select * @param cols The col index to select
* @return submatrix * @return submatrix
*/ */
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix CSRSliceMatrix(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
/*! @return True if the matrix has duplicate entries */ /** @return True if the matrix has duplicate entries */
bool CSRHasDuplicate(CSRMatrix csr); bool CSRHasDuplicate(CSRMatrix csr);
/*! /**
* @brief Sort the column index at each row in ascending order in-place. * @brief Sort the column index at each row in ascending order in-place.
* *
* Only the indices and data arrays (if available) will be mutated. The indptr array * Only the indices and data arrays (if available) will be mutated. The indptr
* stays the same. * array stays the same.
* *
* Examples: * Examples:
* num_rows = 4 * num_rows = 4
...@@ -363,39 +375,39 @@ bool CSRHasDuplicate(CSRMatrix csr); ...@@ -363,39 +375,39 @@ bool CSRHasDuplicate(CSRMatrix csr);
*/ */
void CSRSort_(CSRMatrix* csr); void CSRSort_(CSRMatrix* csr);
/*! /**
* @brief Sort the column index at each row in ascending order. * @brief Sort the column index at each row in ascending order.
* *
* Return a new CSR matrix with sorted column indices and data arrays. * Return a new CSR matrix with sorted column indices and data arrays.
*/ */
inline CSRMatrix CSRSort(CSRMatrix csr) { inline CSRMatrix CSRSort(CSRMatrix csr) {
if (csr.sorted) if (csr.sorted) return csr;
return csr; CSRMatrix ret(
CSRMatrix ret(csr.num_rows, csr.num_cols, csr.num_rows, csr.num_cols, csr.indptr, csr.indices.Clone(),
csr.indptr, csr.indices.Clone(), CSRHasData(csr) ? csr.data.Clone() : csr.data, csr.sorted);
CSRHasData(csr)? csr.data.Clone() : csr.data,
csr.sorted);
CSRSort_(&ret); CSRSort_(&ret);
return ret; return ret;
} }
/*! /**
* @brief Reorder the rows and colmns according to the new row and column order. * @brief Reorder the rows and colmns according to the new row and column order.
* @param csr The input csr matrix. * @param csr The input csr matrix.
* @param new_row_ids the new row Ids (the index is the old row Id) * @param new_row_ids the new row Ids (the index is the old row Id)
* @param new_col_ids the new column Ids (the index is the old col Id). * @param new_col_ids the new column Ids (the index is the old col Id).
*/ */
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); CSRMatrix CSRReorder(
CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/*! /**
* @brief Remove entries from CSR matrix by entry indices (data indices) * @brief Remove entries from CSR matrix by entry indices (data indices)
* @return A new CSR matrix as well as a mapping from the new CSR entries to the old CSR * @return A new CSR matrix as well as a mapping from the new CSR entries to the
* entries. * old CSR entries.
*/ */
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries); CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
/*! /**
* @brief Randomly select a fixed number of non-zero entries along each given row independently. * @brief Randomly select a fixed number of non-zero entries along each given
* row independently.
* *
* The function performs random choices along each row independently. * The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix. * The picked indices are returned in the form of a COO matrix.
...@@ -431,13 +443,10 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries); ...@@ -431,13 +443,10 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
* @note The edges of the entire graph must be ordered by their edge types. * @note The edges of the entire graph must be ordered by their edge types.
*/ */
COOMatrix CSRRowWiseSampling( COOMatrix CSRRowWiseSampling(
CSRMatrix mat, CSRMatrix mat, IdArray rows, int64_t num_samples,
IdArray rows, NDArray prob_or_mask = NDArray(), bool replace = true);
int64_t num_samples,
NDArray prob_or_mask = NDArray(),
bool replace = true);
/*! /**
* @brief Randomly select a fixed number of non-zero entries for each edge type * @brief Randomly select a fixed number of non-zero entries for each edge type
* along each given row independently. * along each given row independently.
* *
...@@ -460,8 +469,8 @@ COOMatrix CSRRowWiseSampling( ...@@ -460,8 +469,8 @@ COOMatrix CSRRowWiseSampling(
* CSRMatrix csr = ...; * CSRMatrix csr = ...;
* IdArray rows = ... ; // [0, 3] * IdArray rows = ... ; // [0, 3]
* std::vector<int64_t> num_samples = {2, 2, 2}; * std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = CSRRowWisePerEtypeSampling(csr, rows, eid2etype_offset, num_samples, * COOMatrix sampled = CSRRowWisePerEtypeSampling(csr, rows, eid2etype_offset,
* FloatArray(), false); * num_samples, FloatArray(), false);
* // possible sampled coo matrix: * // possible sampled coo matrix:
* // sampled.num_rows = 4 * // sampled.num_rows = 4
* // sampled.num_cols = 4 * // sampled.num_cols = 4
...@@ -477,21 +486,20 @@ COOMatrix CSRRowWiseSampling( ...@@ -477,21 +486,20 @@ COOMatrix CSRRowWiseSampling(
* Should be of the same length as the data array. * Should be of the same length as the data array.
* If an empty array is provided, assume uniform. * If an empty array is provided, assume uniform.
* @param replace True if sample with replacement * @param replace True if sample with replacement
* @param rowwise_etype_sorted whether the CSR column indices per row are ordered by edge type. * @param rowwise_etype_sorted whether the CSR column indices per row are
* ordered by edge type.
* @return A COOMatrix storing the picked row, col and data indices. * @return A COOMatrix storing the picked row, col and data indices.
* @note The edges must be ordered by their edge types. * @note The edges must be ordered by their edge types.
*/ */
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, const std::vector<NDArray>& prob_or_mask, bool replace = true,
bool replace = true,
bool rowwise_etype_sorted = false); bool rowwise_etype_sorted = false);
/*! /**
* @brief Select K non-zero entries with the largest weights along each given row. * @brief Select K non-zero entries with the largest weights along each given
* row.
* *
* The function performs top-k selection along each row independently. * The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix. * The picked indices are returned in the form of a COO matrix.
...@@ -520,35 +528,32 @@ COOMatrix CSRRowWisePerEtypeSampling( ...@@ -520,35 +528,32 @@ COOMatrix CSRRowWisePerEtypeSampling(
* @param mat Input CSR matrix. * @param mat Input CSR matrix.
* @param rows Rows to sample from. * @param rows Rows to sample from.
* @param k The K value. * @param k The K value.
* @param weight Weight associated with each entry. Should be of the same length as the * @param weight Weight associated with each entry. Should be of the same length
* data array. If an empty array is provided, assume uniform. * as the data array. If an empty array is provided, assume uniform.
* @param ascending If true, elements are sorted by ascending order, equivalent to find * @param ascending If true, elements are sorted by ascending order, equivalent
* the K smallest values. Otherwise, find K largest values. * to find the K smallest values. Otherwise, find K largest values.
* @return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field
* the index of the picked elements in the value array. * stores the the index of the picked elements in the value array.
*/ */
COOMatrix CSRRowWiseTopk( COOMatrix CSRRowWiseTopk(
CSRMatrix mat, CSRMatrix mat, IdArray rows, int64_t k, FloatArray weight,
IdArray rows,
int64_t k,
FloatArray weight,
bool ascending = false); bool ascending = false);
/**
* @brief Randomly select a fixed number of non-zero entries along each given
/*! * row independently, where the probability of columns to be picked can be
* @brief Randomly select a fixed number of non-zero entries along each given row independently, * biased according to its tag.
* where the probability of columns to be picked can be biased according to its tag.
* *
* Each column is assigned an integer tag which determines its probability to be sampled. * Each column is assigned an integer tag which determines its probability to be
* Users can assign different probability to different tags. * sampled. Users can assign different probability to different tags.
* *
* This function only works with a CSR matrix sorted according to the tag so that entries with * This function only works with a CSR matrix sorted according to the tag so
* the same column tag are arranged in a consecutive range, and the input `tag_offset` represents * that entries with the same column tag are arranged in a consecutive range,
* the boundaries of these ranges. However, the function itself will not check if the input matrix * and the input `tag_offset` represents the boundaries of these ranges.
* has been sorted. It's the caller's responsibility to ensure the input matrix has been sorted * However, the function itself will not check if the input matrix has been
* by `CSRSortByTag` (it will also return a NDArray `tag_offset` which should be used as an input * sorted. It's the caller's responsibility to ensure the input matrix has been
* of this function). * sorted by `CSRSortByTag` (it will also return a NDArray `tag_offset` which
* should be used as an input of this function).
* *
* The picked indices are returned in the form of a COO matrix. * The picked indices are returned in the form of a COO matrix.
* *
...@@ -576,53 +581,48 @@ COOMatrix CSRRowWiseTopk( ...@@ -576,53 +581,48 @@ COOMatrix CSRRowWiseTopk(
* // sampled.rows = [0, 1] * // sampled.rows = [0, 1]
* // sampled.cols = [1, 2] * // sampled.cols = [1, 2]
* // sampled.data = [2, 0] * // sampled.data = [2, 0]
* // Note that in this case, for row 1, the column 3 will never be picked as it has tag 1 and the * // Note that in this case, for row 1, the column 3 will never be picked as it
* has tag 1 and the
* // probability of tag 1 is 0. * // probability of tag 1 is 0.
* *
* *
* @param mat Input CSR matrix. * @param mat Input CSR matrix.
* @param rows Rows to sample from. * @param rows Rows to sample from.
* @param num_samples Number of samples. * @param num_samples Number of samples.
* @param tag_offset The boundaries of tags. Should be of the shape [num_row, num_tags+1] * @param tag_offset The boundaries of tags. Should be of the shape [num_row,
* num_tags+1]
* @param bias Unnormalized probability array. Should be of length num_tags * @param bias Unnormalized probability array. Should be of length num_tags
* @param replace True if sample with replacement * @param replace True if sample with replacement
* @return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field
* the index of the picked elements in the value array. * stores the the index of the picked elements in the value array.
* *
*/ */
COOMatrix CSRRowWiseSamplingBiased( COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat, CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
IdArray rows, FloatArray bias, bool replace = true);
int64_t num_samples,
NDArray tag_offset,
FloatArray bias,
bool replace = true
);
/*! /**
* @brief Uniformly sample row-column pairs whose entries do not exist in the given * @brief Uniformly sample row-column pairs whose entries do not exist in the
* sparse matrix using rejection sampling. * given sparse matrix using rejection sampling.
* *
* @note The number of samples returned may not necessarily be the number of samples * @note The number of samples returned may not necessarily be the number of
* given. * samples given.
* *
* @param csr The CSR matrix. * @param csr The CSR matrix.
* @param num_samples The number of samples. * @param num_samples The number of samples.
* @param num_trials The number of trials. * @param num_trials The number of trials.
* @param exclude_self_loops Do not include the examples where the row equals the column. * @param exclude_self_loops Do not include the examples where the row equals
* the column.
* @param replace Whether to sample with replacement. * @param replace Whether to sample with replacement.
* @param redundancy How much redundant negative examples to take in case of duplicate examples. * @param redundancy How much redundant negative examples to take in case of
* duplicate examples.
* @return A pair of row and column tensors. * @return A pair of row and column tensors.
*/ */
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr, const CSRMatrix& csr, int64_t num_samples, int num_trials,
int64_t num_samples, bool exclude_self_loops, bool replace, double redundancy);
int num_trials,
bool exclude_self_loops, /**
bool replace,
double redundancy);
/*!
* @brief Sort the column index according to the tag of each column. * @brief Sort the column index according to the tag of each column.
* *
* Example: * Example:
...@@ -647,14 +647,13 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( ...@@ -647,14 +647,13 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
* @param tag_array Tag of each column. IdArray with length num_cols * @param tag_array Tag of each column. IdArray with length num_cols
* @param num_tags Number of tags. It should be equal to max(tag_array)+1. * @param num_tags Number of tags. It should be equal to max(tag_array)+1.
* @return 1. A sorted copy of the given CSR matrix * @return 1. A sorted copy of the given CSR matrix
* 2. The split positions of different tags. NDArray of shape (num_rows, num_tags + 1) * 2. The split positions of different tags. NDArray of shape (num_rows,
* num_tags + 1)
*/ */
std::pair<CSRMatrix, NDArray> CSRSortByTag( std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, const CSRMatrix& csr, const IdArray tag_array, int64_t num_tags);
const IdArray tag_array,
int64_t num_tags);
/* /**
* @brief Union two CSRMatrix into one CSRMatrix. * @brief Union two CSRMatrix into one CSRMatrix.
* *
* Two Matrix must have the same shape. * Two Matrix must have the same shape.
...@@ -683,10 +682,9 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag( ...@@ -683,10 +682,9 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
* CSRMatrix_C.num_rows : 3 * CSRMatrix_C.num_rows : 3
* CSRMatrix_C.num_cols : 4 * CSRMatrix_C.num_cols : 4
*/ */
CSRMatrix UnionCsr( CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);
const std::vector<CSRMatrix>& csrs);
/*! /**
* @brief Union a list CSRMatrix into one CSRMatrix. * @brief Union a list CSRMatrix into one CSRMatrix.
* *
* Examples: * Examples:
...@@ -714,14 +712,15 @@ CSRMatrix UnionCsr( ...@@ -714,14 +712,15 @@ CSRMatrix UnionCsr(
* CSRMatrix_C.num_cols : 5 * CSRMatrix_C.num_cols : 5
* *
* @param csrs The input list of csr matrix. * @param csrs The input list of csr matrix.
* @param src_offset A list of integers recording src vertix id offset of each Matrix in csrs * @param src_offset A list of integers recording src vertix id offset of each
* @param src_offset A list of integers recording dst vertix id offset of each Matrix in csrs * Matrix in csrs
* @param src_offset A list of integers recording dst vertix id offset of each
* Matrix in csrs
* @return The combined CSRMatrix. * @return The combined CSRMatrix.
*/ */
CSRMatrix DisjointUnionCsr( CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs);
const std::vector<CSRMatrix>& csrs);
/*! /**
* @brief CSRMatrix toSimple. * @brief CSRMatrix toSimple.
* *
* A = [[0, 0, 0], * A = [[0, 0, 0],
...@@ -739,13 +738,13 @@ CSRMatrix DisjointUnionCsr( ...@@ -739,13 +738,13 @@ CSRMatrix DisjointUnionCsr(
* edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4] * edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4]
* *
* @return The simplified CSRMatrix * @return The simplified CSRMatrix
* The count recording the number of duplicated edges from the original graph. * The count recording the number of duplicated edges from the original
* The edge mapping from the edge IDs of original graph to those of the * graph. The edge mapping from the edge IDs of original graph to those of the
* returned graph. * returned graph.
*/ */
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr); std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr);
/*! /**
* @brief Split a CSRMatrix into multiple disjoint components. * @brief Split a CSRMatrix into multiple disjoint components.
* *
* Examples: * Examples:
...@@ -790,13 +789,12 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr); ...@@ -790,13 +789,12 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr);
* @return A list of CSRMatrixes representing each disjoint components. * @return A list of CSRMatrixes representing each disjoint components.
*/ */
std::vector<CSRMatrix> DisjointPartitionCsrBySizes( std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
const CSRMatrix &csrs, const CSRMatrix& csrs, const uint64_t batch_size,
const uint64_t batch_size, const std::vector<uint64_t>& edge_cumsum,
const std::vector<uint64_t> &edge_cumsum, const std::vector<uint64_t>& src_vertex_cumsum,
const std::vector<uint64_t> &src_vertex_cumsum, const std::vector<uint64_t>& dst_vertex_cumsum);
const std::vector<uint64_t> &dst_vertex_cumsum);
/*! /**
* @brief Slice a contiguous chunk from a CSRMatrix * @brief Slice a contiguous chunk from a CSRMatrix
* *
* Examples: * Examples:
...@@ -832,10 +830,9 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes( ...@@ -832,10 +830,9 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
* @return CSRMatrix representing the chunk. * @return CSRMatrix representing the chunk.
*/ */
CSRMatrix CSRSliceContiguousChunk( CSRMatrix CSRSliceContiguousChunk(
const CSRMatrix &csr, const CSRMatrix& csr, const std::vector<uint64_t>& edge_range,
const std::vector<uint64_t> &edge_range, const std::vector<uint64_t>& src_vertex_range,
const std::vector<uint64_t> &src_vertex_range, const std::vector<uint64_t>& dst_vertex_range);
const std::vector<uint64_t> &dst_vertex_range);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/macro.h * @file dgl/aten/macro.h
* @brief Common macros for aten package. * @brief Common macros for aten package.
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
///////////////////////// Dispatchers ////////////////////////// ///////////////////////// Dispatchers //////////////////////////
/* /**
* Dispatch according to device: * Dispatch according to device:
* *
* ATEN_XPU_SWITCH(array->ctx.device_type, XPU, { * ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
} \ } \
} while (0) } while (0)
/* /**
* Dispatch according to device: * Dispatch according to device:
* *
* XXX(minjie): temporary macro that allows CUDA operator * XXX(minjie): temporary macro that allows CUDA operator
...@@ -59,7 +59,7 @@ ...@@ -59,7 +59,7 @@
#define ATEN_XPU_SWITCH_CUDA ATEN_XPU_SWITCH #define ATEN_XPU_SWITCH_CUDA ATEN_XPU_SWITCH
#endif // DGL_USE_CUDA #endif // DGL_USE_CUDA
/* /**
* Dispatch according to integral type (either int32 or int64): * Dispatch according to integral type (either int32 or int64):
* *
* ATEN_ID_TYPE_SWITCH(array->dtype, IdType, { * ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
...@@ -81,7 +81,7 @@ ...@@ -81,7 +81,7 @@
} \ } \
} while (0) } while (0)
/* /**
* Dispatch according to bits (either int32 or int64): * Dispatch according to bits (either int32 or int64):
* *
* ATEN_ID_BITS_SWITCH(bits, IdType, { * ATEN_ID_BITS_SWITCH(bits, IdType, {
...@@ -104,7 +104,7 @@ ...@@ -104,7 +104,7 @@
} \ } \
} while (0) } while (0)
/* /**
* Dispatch according to float type (either float32 or float64): * Dispatch according to float type (either float32 or float64):
* *
* ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, { * ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, {
...@@ -128,7 +128,7 @@ ...@@ -128,7 +128,7 @@
} \ } \
} while (0) } while (0)
/* /**
* Dispatch according to float type, including 16bits (float16/bfloat16/float32/float64). * Dispatch according to float type, including 16bits (float16/bfloat16/float32/float64).
*/ */
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
...@@ -185,7 +185,7 @@ ...@@ -185,7 +185,7 @@
ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__}) ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__})
#endif // DGL_USE_CUDA #endif // DGL_USE_CUDA
/* /**
* Dispatch according to data type (int32, int64, float32 or float64): * Dispatch according to data type (int32, int64, float32 or float64):
* *
* ATEN_DTYPE_SWITCH(array->dtype, DType, { * ATEN_DTYPE_SWITCH(array->dtype, DType, {
...@@ -212,7 +212,7 @@ ...@@ -212,7 +212,7 @@
} \ } \
} while (0) } while (0)
/* /**
* Dispatch according to data type (int8, uint8, float32 or float64): * Dispatch according to data type (int8, uint8, float32 or float64):
* *
* ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(array->dtype, DType, { * ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(array->dtype, DType, {
...@@ -239,7 +239,7 @@ ...@@ -239,7 +239,7 @@
} \ } \
} while (0) } while (0)
/* /**
* Dispatch data type only based on bit-width (8-bit, 16-bit, 32-bit, 64-bit): * Dispatch data type only based on bit-width (8-bit, 16-bit, 32-bit, 64-bit):
* *
* ATEN_DTYPE_BITS_ONLY_SWITCH(array->dtype, DType, { * ATEN_DTYPE_BITS_ONLY_SWITCH(array->dtype, DType, {
...@@ -268,7 +268,7 @@ ...@@ -268,7 +268,7 @@
} \ } \
} while (0) } while (0)
/* /**
* Dispatch according to integral type of CSR graphs. * Dispatch according to integral type of CSR graphs.
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message. * Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
*/ */
...@@ -306,7 +306,7 @@ ...@@ -306,7 +306,7 @@
<< "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). " \ << "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). " \
<< "Or " << (#VAR1) << "(" << (VAR1)->ctx << ")" << " is pinned"; << "Or " << (#VAR1) << "(" << (VAR1)->ctx << ")" << " is pinned";
/* /**
* Macro to dispatch according to the context of array and dtype of csr * Macro to dispatch according to the context of array and dtype of csr
* to enable CUDA UVA ops. * to enable CUDA UVA ops.
* Context check is covered here to avoid confusion with CHECK_SAME_CONTEXT. * Context check is covered here to avoid confusion with CHECK_SAME_CONTEXT.
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/spmat.h * @file dgl/aten/spmat.h
* @brief Sparse matrix definitions * @brief Sparse matrix definitions
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace dgl { namespace dgl {
/*! /**
* @brief Sparse format. * @brief Sparse format.
*/ */
enum class SparseFormat { enum class SparseFormat {
...@@ -23,7 +23,7 @@ enum class SparseFormat { ...@@ -23,7 +23,7 @@ enum class SparseFormat {
kCSC = 3, kCSC = 3,
}; };
/*! /**
* @brief Sparse format codes * @brief Sparse format codes
*/ */
const dgl_format_code_t ALL_CODE = 0x7; const dgl_format_code_t ALL_CODE = 0x7;
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/types.h * @file dgl/aten/types.h
* @brief Array and ID types * @brief Array and ID types
...@@ -14,7 +14,7 @@ namespace dgl { ...@@ -14,7 +14,7 @@ namespace dgl {
typedef uint64_t dgl_id_t; typedef uint64_t dgl_id_t;
typedef uint64_t dgl_type_t; typedef uint64_t dgl_type_t;
/*! @brief Type for dgl fomrat code, whose binary representation indices /** @brief Type for dgl fomrat code, whose binary representation indices
* which sparse format is in use and which is not. * which sparse format is in use and which is not.
* *
* Suppose the binary representation is xyz, then * Suppose the binary representation is xyz, then
......
/*! /**
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* @file dgl/heterograph_interface.h * @file dgl/heterograph_interface.h
* @brief DGL heterogeneous graph index class. * @brief DGL heterogeneous graph index class.
...@@ -31,13 +31,13 @@ typedef std::shared_ptr<FlattenedHeteroGraph> FlattenedHeteroGraphPtr; ...@@ -31,13 +31,13 @@ typedef std::shared_ptr<FlattenedHeteroGraph> FlattenedHeteroGraphPtr;
struct HeteroSubgraph; struct HeteroSubgraph;
/*! @brief Enum class for edge direction */ /** @brief Enum class for edge direction */
enum class EdgeDir { enum class EdgeDir {
kIn, // in edge direction kIn, // in edge direction
kOut // out edge direction kOut // out edge direction
}; };
/*! /**
* @brief Base heterogenous graph. * @brief Base heterogenous graph.
* *
* In heterograph, nodes represent entities and edges represent relations. * In heterograph, nodes represent entities and edges represent relations.
...@@ -58,22 +58,22 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -58,22 +58,22 @@ class BaseHeteroGraph : public runtime::Object {
////////////////////// query/operations on meta graph /////////////////////// ////////////////////// query/operations on meta graph ///////////////////////
/*! @return the number of vertex types */ /** @return the number of vertex types */
virtual uint64_t NumVertexTypes() const { return meta_graph_->NumVertices(); } virtual uint64_t NumVertexTypes() const { return meta_graph_->NumVertices(); }
/*! @return the number of edge types */ /** @return the number of edge types */
virtual uint64_t NumEdgeTypes() const { return meta_graph_->NumEdges(); } virtual uint64_t NumEdgeTypes() const { return meta_graph_->NumEdges(); }
/*! @return given the edge type, find the source type */ /** @return given the edge type, find the source type */
virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes( virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(
dgl_type_t etype) const { dgl_type_t etype) const {
return meta_graph_->FindEdge(etype); return meta_graph_->FindEdge(etype);
} }
/*! @return the meta graph */ /** @return the meta graph */
virtual GraphPtr meta_graph() const { return meta_graph_; } virtual GraphPtr meta_graph() const { return meta_graph_; }
/*! /**
* @brief Return the bipartite graph of the given edge type. * @brief Return the bipartite graph of the given edge type.
* @param etype The edge type. * @param etype The edge type.
* @return The bipartite graph. * @return The bipartite graph.
...@@ -82,90 +82,90 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -82,90 +82,90 @@ class BaseHeteroGraph : public runtime::Object {
///////////////////// query/operations on realized graph ///////////////////// ///////////////////// query/operations on realized graph /////////////////////
/*! @brief Add vertices to the given vertex type */ /** @brief Add vertices to the given vertex type */
virtual void AddVertices(dgl_type_t vtype, uint64_t num_vertices) = 0; virtual void AddVertices(dgl_type_t vtype, uint64_t num_vertices) = 0;
/*! @brief Add one edge to the given edge type */ /** @brief Add one edge to the given edge type */
virtual void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) = 0; virtual void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) = 0;
/*! @brief Add edges to the given edge type */ /** @brief Add edges to the given edge type */
virtual void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) = 0; virtual void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) = 0;
/*! /**
* @brief Clear the graph. Remove all vertices/edges. * @brief Clear the graph. Remove all vertices/edges.
*/ */
virtual void Clear() = 0; virtual void Clear() = 0;
/*! /**
* @brief Get the data type of node and edge IDs of this graph. * @brief Get the data type of node and edge IDs of this graph.
*/ */
virtual DGLDataType DataType() const = 0; virtual DGLDataType DataType() const = 0;
/*! /**
* @brief Get the device context of this graph. * @brief Get the device context of this graph.
*/ */
virtual DGLContext Context() const = 0; virtual DGLContext Context() const = 0;
/*! /**
* @brief Pin graph. * @brief Pin graph.
*/ */
virtual void PinMemory_() = 0; virtual void PinMemory_() = 0;
/*! /**
* @brief Check if this graph is pinned. * @brief Check if this graph is pinned.
*/ */
virtual bool IsPinned() const = 0; virtual bool IsPinned() const = 0;
/*! /**
* @brief Record stream for this graph. * @brief Record stream for this graph.
* @param stream The stream that is using the graph * @param stream The stream that is using the graph
*/ */
virtual void RecordStream(DGLStreamHandle stream) = 0; virtual void RecordStream(DGLStreamHandle stream) = 0;
/*! /**
* @brief Get the number of integer bits used to store node/edge ids (32 or * @brief Get the number of integer bits used to store node/edge ids (32 or
* 64). * 64).
*/ */
// TODO(BarclayII) replace NumBits() calls to DataType() calls // TODO(BarclayII) replace NumBits() calls to DataType() calls
virtual uint8_t NumBits() const = 0; virtual uint8_t NumBits() const = 0;
/*! /**
* @return whether the graph is a multigraph * @return whether the graph is a multigraph
*/ */
virtual bool IsMultigraph() const = 0; virtual bool IsMultigraph() const = 0;
/*! @return whether the graph is read-only */ /** @return whether the graph is read-only */
virtual bool IsReadonly() const = 0; virtual bool IsReadonly() const = 0;
/*! @return the number of vertices in the graph.*/ /** @return the number of vertices in the graph.*/
virtual uint64_t NumVertices(dgl_type_t vtype) const = 0; virtual uint64_t NumVertices(dgl_type_t vtype) const = 0;
/*! @return the number of vertices for each type in the graph as a vector */ /** @return the number of vertices for each type in the graph as a vector */
inline virtual std::vector<int64_t> NumVerticesPerType() const { inline virtual std::vector<int64_t> NumVerticesPerType() const {
LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on this object."; LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on this object.";
return {}; return {};
} }
/*! @return the number of edges in the graph.*/ /** @return the number of edges in the graph.*/
virtual uint64_t NumEdges(dgl_type_t etype) const = 0; virtual uint64_t NumEdges(dgl_type_t etype) const = 0;
/*! @return true if the given vertex is in the graph.*/ /** @return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const = 0; virtual bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const = 0;
/*! @return a 0-1 array indicating whether the given vertices are in the /** @return a 0-1 array indicating whether the given vertices are in the
* graph. * graph.
*/ */
virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const = 0; virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const = 0;
/*! @return true if the given edge is in the graph.*/ /** @return true if the given edge is in the graph.*/
virtual bool HasEdgeBetween( virtual bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0; dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
/*! @return a 0-1 array indicating whether the given edges are in the graph.*/ /** @return a 0-1 array indicating whether the given edges are in the graph.*/
virtual BoolArray HasEdgesBetween( virtual BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0; dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0;
/*! /**
* @brief Find the predecessors of a vertex. * @brief Find the predecessors of a vertex.
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type
* of the given edge type. * of the given edge type.
...@@ -175,7 +175,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -175,7 +175,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const = 0; virtual IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const = 0;
/*! /**
* @brief Find the successors of a vertex. * @brief Find the successors of a vertex.
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type
* of the given edge type. * of the given edge type.
...@@ -185,7 +185,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -185,7 +185,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual IdArray Successors(dgl_type_t etype, dgl_id_t src) const = 0; virtual IdArray Successors(dgl_type_t etype, dgl_id_t src) const = 0;
/*! /**
* @brief Get all edge ids between the two given endpoints * @brief Get all edge ids between the two given endpoints
* @note The given src and dst vertices should belong to the source vertex * @note The given src and dst vertices should belong to the source vertex
* type and the dest vertex type of the given edge type, respectively. \param * type and the dest vertex type of the given edge type, respectively. \param
...@@ -195,7 +195,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -195,7 +195,7 @@ class BaseHeteroGraph : public runtime::Object {
virtual IdArray EdgeId( virtual IdArray EdgeId(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0; dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
/*! /**
* @brief Get all edge ids between the given endpoint pairs. * @brief Get all edge ids between the given endpoint pairs.
* *
* @param etype The edge type * @param etype The edge type
...@@ -206,7 +206,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -206,7 +206,7 @@ class BaseHeteroGraph : public runtime::Object {
virtual EdgeArray EdgeIdsAll( virtual EdgeArray EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const = 0; dgl_type_t etype, IdArray src, IdArray dst) const = 0;
/*! /**
* @brief Get edge ids between the given endpoint pairs. * @brief Get edge ids between the given endpoint pairs.
* *
* Only find one matched edge Ids even if there are multiple matches due to * Only find one matched edge Ids even if there are multiple matches due to
...@@ -221,7 +221,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -221,7 +221,7 @@ class BaseHeteroGraph : public runtime::Object {
virtual IdArray EdgeIdsOne( virtual IdArray EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const = 0; dgl_type_t etype, IdArray src, IdArray dst) const = 0;
/*! /**
* @brief Find the edge ID and return the pair of endpoints * @brief Find the edge ID and return the pair of endpoints
* @param etype The edge type * @param etype The edge type
* @param eid The edge ID * @param eid The edge ID
...@@ -231,7 +231,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -231,7 +231,7 @@ class BaseHeteroGraph : public runtime::Object {
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge( virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const = 0; dgl_type_t etype, dgl_id_t eid) const = 0;
/*! /**
* @brief Find the edge IDs and return their source and target node IDs. * @brief Find the edge IDs and return their source and target node IDs.
* @param etype The edge type * @param etype The edge type
* @param eids The edge ID array. * @param eids The edge ID array.
...@@ -240,7 +240,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -240,7 +240,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const = 0; virtual EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const = 0;
/*! /**
* @brief Get the in edges of the vertex. * @brief Get the in edges of the vertex.
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type
* of the given edge type. * of the given edge type.
...@@ -250,7 +250,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -250,7 +250,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const = 0; virtual EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const = 0;
/*! /**
* @brief Get the in edges of the vertices. * @brief Get the in edges of the vertices.
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type
* of the given edge type. * of the given edge type.
...@@ -260,7 +260,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -260,7 +260,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual EdgeArray InEdges(dgl_type_t etype, IdArray vids) const = 0; virtual EdgeArray InEdges(dgl_type_t etype, IdArray vids) const = 0;
/*! /**
* @brief Get the out edges of the vertex. * @brief Get the out edges of the vertex.
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type
* of the given edge type. * of the given edge type.
...@@ -270,7 +270,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -270,7 +270,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const = 0; virtual EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const = 0;
/*! /**
* @brief Get the out edges of the vertices. * @brief Get the out edges of the vertices.
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type
* of the given edge type. * of the given edge type.
...@@ -280,7 +280,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -280,7 +280,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const = 0; virtual EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const = 0;
/*! /**
* @brief Get all the edges in the graph. * @brief Get all the edges in the graph.
* @note If order is "srcdst", the returned edges list is sorted by their src * @note If order is "srcdst", the returned edges list is sorted by their src
* and dst ids. If order is "eid", they are in their edge id order. Otherwise, * and dst ids. If order is "eid", they are in their edge id order. Otherwise,
...@@ -291,7 +291,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -291,7 +291,7 @@ class BaseHeteroGraph : public runtime::Object {
virtual EdgeArray Edges( virtual EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const = 0; dgl_type_t etype, const std::string& order = "") const = 0;
/*! /**
* @brief Get the in degree of the given vertex. * @brief Get the in degree of the given vertex.
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type
* of the given edge type. * of the given edge type.
...@@ -301,7 +301,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -301,7 +301,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const = 0; virtual uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const = 0;
/*! /**
* @brief Get the in degrees of the given vertices. * @brief Get the in degrees of the given vertices.
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type
* of the given edge type. * of the given edge type.
...@@ -311,7 +311,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -311,7 +311,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const = 0; virtual DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const = 0;
/*! /**
* @brief Get the out degree of the given vertex. * @brief Get the out degree of the given vertex.
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type
* of the given edge type. * of the given edge type.
...@@ -321,7 +321,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -321,7 +321,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const = 0; virtual uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const = 0;
/*! /**
* @brief Get the out degrees of the given vertices. * @brief Get the out degrees of the given vertices.
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type
* of the given edge type. * of the given edge type.
...@@ -331,7 +331,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -331,7 +331,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const = 0; virtual DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const = 0;
/*! /**
* @brief Return the successor vector * @brief Return the successor vector
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type
* of the given edge type. * of the given edge type.
...@@ -340,7 +340,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -340,7 +340,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const = 0; virtual DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const = 0;
/*! /**
* @brief Return the out edge id vector * @brief Return the out edge id vector
* @note The given vertex should belong to the source vertex type * @note The given vertex should belong to the source vertex type
* of the given edge type. * of the given edge type.
...@@ -349,7 +349,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -349,7 +349,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0; virtual DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0;
/*! /**
* @brief Return the predecessor vector * @brief Return the predecessor vector
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type
* of the given edge type. * of the given edge type.
...@@ -358,7 +358,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -358,7 +358,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const = 0; virtual DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const = 0;
/*! /**
* @brief Return the in edge id vector * @brief Return the in edge id vector
* @note The given vertex should belong to the dest vertex type * @note The given vertex should belong to the dest vertex type
* of the given edge type. * of the given edge type.
...@@ -367,7 +367,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -367,7 +367,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0; virtual DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0;
/*! /**
* @brief Get the adjacency matrix of the graph. * @brief Get the adjacency matrix of the graph.
* *
* TODO(minjie): deprecate this interface; replace it with GetXXXMatrix. * TODO(minjie): deprecate this interface; replace it with GetXXXMatrix.
...@@ -388,7 +388,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -388,7 +388,7 @@ class BaseHeteroGraph : public runtime::Object {
virtual std::vector<IdArray> GetAdj( virtual std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string& fmt) const = 0; dgl_type_t etype, bool transpose, const std::string& fmt) const = 0;
/*! /**
* @brief Determine which format to use with a preference. * @brief Determine which format to use with a preference.
* *
...@@ -402,35 +402,35 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -402,35 +402,35 @@ class BaseHeteroGraph : public runtime::Object {
virtual SparseFormat SelectFormat( virtual SparseFormat SelectFormat(
dgl_type_t etype, dgl_format_code_t preferred_formats) const = 0; dgl_type_t etype, dgl_format_code_t preferred_formats) const = 0;
/*! /**
* @brief Return sparse formats already created for the graph. * @brief Return sparse formats already created for the graph.
* *
* @return a number of type dgl_format_code_t. * @return a number of type dgl_format_code_t.
*/ */
virtual dgl_format_code_t GetCreatedFormats() const = 0; virtual dgl_format_code_t GetCreatedFormats() const = 0;
/*! /**
* @brief Return allowed sparse formats for the graph. * @brief Return allowed sparse formats for the graph.
* *
* @return a number of type dgl_format_code_t. * @return a number of type dgl_format_code_t.
*/ */
virtual dgl_format_code_t GetAllowedFormats() const = 0; virtual dgl_format_code_t GetAllowedFormats() const = 0;
/*! /**
* @brief Return the graph in specified available formats. * @brief Return the graph in specified available formats.
* *
* @return The new graph. * @return The new graph.
*/ */
virtual HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const = 0; virtual HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const = 0;
/*! /**
* @brief Get adjacency matrix in COO format. * @brief Get adjacency matrix in COO format.
* @param etype Edge type. * @param etype Edge type.
* @return COO matrix. * @return COO matrix.
*/ */
virtual aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const = 0; virtual aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const = 0;
/*! /**
* @brief Get adjacency matrix in CSR format. * @brief Get adjacency matrix in CSR format.
* *
* The row and column sizes are equal to the number of dsttype and srctype * The row and column sizes are equal to the number of dsttype and srctype
...@@ -441,7 +441,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -441,7 +441,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const = 0; virtual aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const = 0;
/*! /**
* @brief Get adjacency matrix in CSC format. * @brief Get adjacency matrix in CSC format.
* *
* A CSC matrix is equivalent to the transpose of a CSR matrix. * A CSC matrix is equivalent to the transpose of a CSR matrix.
...@@ -453,7 +453,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -453,7 +453,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0; virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0;
/*! /**
* @brief Extract the induced subgraph by the given vertices. * @brief Extract the induced subgraph by the given vertices.
* *
* The length of the given vector should be equal to the number of vertex * The length of the given vector should be equal to the number of vertex
...@@ -467,7 +467,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -467,7 +467,7 @@ class BaseHeteroGraph : public runtime::Object {
virtual HeteroSubgraph VertexSubgraph( virtual HeteroSubgraph VertexSubgraph(
const std::vector<IdArray>& vids) const = 0; const std::vector<IdArray>& vids) const = 0;
/*! /**
* @brief Extract the induced subgraph by the given edges. * @brief Extract the induced subgraph by the given edges.
* *
* The length of the given vector should be equal to the number of edge types. * The length of the given vector should be equal to the number of edge types.
...@@ -482,7 +482,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -482,7 +482,7 @@ class BaseHeteroGraph : public runtime::Object {
virtual HeteroSubgraph EdgeSubgraph( virtual HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0; const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;
/*! /**
* @brief Convert the list of requested unitgraph graphs into a single * @brief Convert the list of requested unitgraph graphs into a single
* unitgraph graph. * unitgraph graph.
* *
...@@ -496,7 +496,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -496,7 +496,7 @@ class BaseHeteroGraph : public runtime::Object {
return nullptr; return nullptr;
} }
/*! @brief Cast this graph to immutable graph */ /** @brief Cast this graph to immutable graph */
virtual GraphPtr AsImmutableGraph() const { virtual GraphPtr AsImmutableGraph() const {
LOG(FATAL) << "AsImmutableGraph not supported."; LOG(FATAL) << "AsImmutableGraph not supported.";
return nullptr; return nullptr;
...@@ -506,7 +506,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -506,7 +506,7 @@ class BaseHeteroGraph : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object);
protected: protected:
/*! @brief meta graph */ /** @brief meta graph */
GraphPtr meta_graph_; GraphPtr meta_graph_;
// empty constructor // empty constructor
...@@ -516,7 +516,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -516,7 +516,7 @@ class BaseHeteroGraph : public runtime::Object {
// Define HeteroGraphRef // Define HeteroGraphRef
DGL_DEFINE_OBJECT_REF(HeteroGraphRef, BaseHeteroGraph); DGL_DEFINE_OBJECT_REF(HeteroGraphRef, BaseHeteroGraph);
/*! /**
* @brief Hetero-subgraph data structure. * @brief Hetero-subgraph data structure.
* *
* This class can be used as arguments and return values of a C API. * This class can be used as arguments and return values of a C API.
...@@ -531,16 +531,16 @@ DGL_DEFINE_OBJECT_REF(HeteroGraphRef, BaseHeteroGraph); ...@@ -531,16 +531,16 @@ DGL_DEFINE_OBJECT_REF(HeteroGraphRef, BaseHeteroGraph);
* </code> * </code>
*/ */
struct HeteroSubgraph : public runtime::Object { struct HeteroSubgraph : public runtime::Object {
/*! @brief The heterograph. */ /** @brief The heterograph. */
HeteroGraphPtr graph; HeteroGraphPtr graph;
/*! /**
* @brief The induced vertex ids of each entity type. * @brief The induced vertex ids of each entity type.
* The vector length is equal to the number of vertex types in the parent * The vector length is equal to the number of vertex types in the parent
* graph. Each array i has the same length as the number of vertices in type * graph. Each array i has the same length as the number of vertices in type
* i. Empty array is allowed if the mapping is identity. * i. Empty array is allowed if the mapping is identity.
*/ */
std::vector<IdArray> induced_vertices; std::vector<IdArray> induced_vertices;
/*! /**
* @brief The induced edge ids of each relation type. * @brief The induced edge ids of each relation type.
* The vector length is equal to the number of edge types in the parent graph. * The vector length is equal to the number of edge types in the parent graph.
* Each array i has the same length as the number of edges in type i. * Each array i has the same length as the number of edges in type i.
...@@ -555,46 +555,46 @@ struct HeteroSubgraph : public runtime::Object { ...@@ -555,46 +555,46 @@ struct HeteroSubgraph : public runtime::Object {
// Define HeteroSubgraphRef // Define HeteroSubgraphRef
DGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph); DGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph);
/*! @brief The flattened heterograph */ /** @brief The flattened heterograph */
struct FlattenedHeteroGraph : public runtime::Object { struct FlattenedHeteroGraph : public runtime::Object {
/*! @brief The graph */ /** @brief The graph */
HeteroGraphRef graph; HeteroGraphRef graph;
/*! /**
* @brief Mapping from source node ID to node type in parent graph * @brief Mapping from source node ID to node type in parent graph
* @note The induced type array guarantees that the same type always appear * @note The induced type array guarantees that the same type always appear
* contiguously. * contiguously.
*/ */
IdArray induced_srctype; IdArray induced_srctype;
/*! /**
* @brief The set of node types in parent graph appearing in source nodes. * @brief The set of node types in parent graph appearing in source nodes.
*/ */
IdArray induced_srctype_set; IdArray induced_srctype_set;
/*! @brief Mapping from source node ID to local node ID in parent graph */ /** @brief Mapping from source node ID to local node ID in parent graph */
IdArray induced_srcid; IdArray induced_srcid;
/*! /**
* @brief Mapping from edge ID to edge type in parent graph * @brief Mapping from edge ID to edge type in parent graph
* @note The induced type array guarantees that the same type always appear * @note The induced type array guarantees that the same type always appear
* contiguously. * contiguously.
*/ */
IdArray induced_etype; IdArray induced_etype;
/*! /**
* @brief The set of edge types in parent graph appearing in edges. * @brief The set of edge types in parent graph appearing in edges.
*/ */
IdArray induced_etype_set; IdArray induced_etype_set;
/*! @brief Mapping from edge ID to local edge ID in parent graph */ /** @brief Mapping from edge ID to local edge ID in parent graph */
IdArray induced_eid; IdArray induced_eid;
/*! /**
* @brief Mapping from destination node ID to node type in parent graph * @brief Mapping from destination node ID to node type in parent graph
* @note The induced type array guarantees that the same type always appear * @note The induced type array guarantees that the same type always appear
* contiguously. * contiguously.
*/ */
IdArray induced_dsttype; IdArray induced_dsttype;
/*! /**
* @brief The set of node types in parent graph appearing in destination * @brief The set of node types in parent graph appearing in destination
* nodes. * nodes.
*/ */
IdArray induced_dsttype_set; IdArray induced_dsttype_set;
/*! @brief Mapping from destination node ID to local node ID in parent graph /** @brief Mapping from destination node ID to local node ID in parent graph
*/ */
IdArray induced_dstid; IdArray induced_dstid;
...@@ -618,7 +618,7 @@ DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph); ...@@ -618,7 +618,7 @@ DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
// Declarations of functions and algorithms // Declarations of functions and algorithms
/*! /**
* @brief Create a heterograph from meta graph and a list of bipartite graph, * @brief Create a heterograph from meta graph and a list of bipartite graph,
* additionally specifying number of nodes per type. * additionally specifying number of nodes per type.
*/ */
...@@ -626,7 +626,7 @@ HeteroGraphPtr CreateHeteroGraph( ...@@ -626,7 +626,7 @@ HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs, GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type = {}); const std::vector<int64_t>& num_nodes_per_type = {});
/*! /**
* @brief Create a heterograph from COO input. * @brief Create a heterograph from COO input.
* @param num_vtypes Number of vertex types. Must be 1 or 2. * @param num_vtypes Number of vertex types. Must be 1 or 2.
* @param num_src Number of nodes in the source type. * @param num_src Number of nodes in the source type.
...@@ -645,7 +645,7 @@ HeteroGraphPtr CreateFromCOO( ...@@ -645,7 +645,7 @@ HeteroGraphPtr CreateFromCOO(
IdArray col, bool row_sorted = false, bool col_sorted = false, IdArray col, bool row_sorted = false, bool col_sorted = false,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! /**
* @brief Create a heterograph from COO input. * @brief Create a heterograph from COO input.
* @param num_vtypes Number of vertex types. Must be 1 or 2. * @param num_vtypes Number of vertex types. Must be 1 or 2.
* @param mat The COO matrix * @param mat The COO matrix
...@@ -656,7 +656,7 @@ HeteroGraphPtr CreateFromCOO( ...@@ -656,7 +656,7 @@ HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! /**
* @brief Create a heterograph from CSR input. * @brief Create a heterograph from CSR input.
* @param num_vtypes Number of vertex types. Must be 1 or 2. * @param num_vtypes Number of vertex types. Must be 1 or 2.
* @param num_src Number of nodes in the source type. * @param num_src Number of nodes in the source type.
...@@ -671,7 +671,7 @@ HeteroGraphPtr CreateFromCSR( ...@@ -671,7 +671,7 @@ HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE); IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
/*! /**
* @brief Create a heterograph from CSR input. * @brief Create a heterograph from CSR input.
* @param num_vtypes Number of vertex types. Must be 1 or 2. * @param num_vtypes Number of vertex types. Must be 1 or 2.
* @param mat The CSR matrix * @param mat The CSR matrix
...@@ -682,7 +682,7 @@ HeteroGraphPtr CreateFromCSR( ...@@ -682,7 +682,7 @@ HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! /**
* @brief Create a heterograph from CSC input. * @brief Create a heterograph from CSC input.
* @param num_vtypes Number of vertex types. Must be 1 or 2. * @param num_vtypes Number of vertex types. Must be 1 or 2.
* @param num_src Number of nodes in the source type. * @param num_src Number of nodes in the source type.
...@@ -697,7 +697,7 @@ HeteroGraphPtr CreateFromCSC( ...@@ -697,7 +697,7 @@ HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE); IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
/*! /**
* @brief Create a heterograph from CSC input. * @brief Create a heterograph from CSC input.
* @param num_vtypes Number of vertex types. Must be 1 or 2. * @param num_vtypes Number of vertex types. Must be 1 or 2.
* @param mat The CSC matrix * @param mat The CSC matrix
...@@ -708,7 +708,7 @@ HeteroGraphPtr CreateFromCSC( ...@@ -708,7 +708,7 @@ HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! /**
* @brief Extract the subgraph of the in edges of the given nodes. * @brief Extract the subgraph of the in edges of the given nodes.
* @param graph Graph * @param graph Graph
* @param nodes Node IDs of each type * @param nodes Node IDs of each type
...@@ -720,7 +720,7 @@ HeteroSubgraph InEdgeGraph( ...@@ -720,7 +720,7 @@ HeteroSubgraph InEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, const HeteroGraphPtr graph, const std::vector<IdArray>& nodes,
bool relabel_nodes = false); bool relabel_nodes = false);
/*! /**
* @brief Extract the subgraph of the out edges of the given nodes. * @brief Extract the subgraph of the out edges of the given nodes.
* @param graph Graph * @param graph Graph
* @param nodes Node IDs of each type * @param nodes Node IDs of each type
...@@ -732,7 +732,7 @@ HeteroSubgraph OutEdgeGraph( ...@@ -732,7 +732,7 @@ HeteroSubgraph OutEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, const HeteroGraphPtr graph, const std::vector<IdArray>& nodes,
bool relabel_nodes = false); bool relabel_nodes = false);
/*! /**
* @brief Joint union multiple graphs into one graph. * @brief Joint union multiple graphs into one graph.
* *
* All input graphs should have the same metagraph. * All input graphs should have the same metagraph.
...@@ -746,7 +746,7 @@ HeteroSubgraph OutEdgeGraph( ...@@ -746,7 +746,7 @@ HeteroSubgraph OutEdgeGraph(
HeteroGraphPtr JointUnionHeteroGraph( HeteroGraphPtr JointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs); GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*! /**
* @brief Union multiple graphs into one with each input graph as one disjoint * @brief Union multiple graphs into one with each input graph as one disjoint
* component. * component.
* *
...@@ -766,7 +766,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph( ...@@ -766,7 +766,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
HeteroGraphPtr DisjointUnionHeteroGraph2( HeteroGraphPtr DisjointUnionHeteroGraph2(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs); GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*! /**
* @brief Slice a contiguous subgraph, e.g. retrieve a component graph from a * @brief Slice a contiguous subgraph, e.g. retrieve a component graph from a
* batched graph. * batched graph.
* *
...@@ -785,7 +785,7 @@ HeteroGraphPtr SliceHeteroGraph( ...@@ -785,7 +785,7 @@ HeteroGraphPtr SliceHeteroGraph(
IdArray num_nodes_per_type, IdArray start_nid_per_type, IdArray num_nodes_per_type, IdArray start_nid_per_type,
IdArray num_edges_per_type, IdArray start_eid_per_type); IdArray num_edges_per_type, IdArray start_eid_per_type);
/*! /**
* @brief Split a graph into multiple disjoin components. * @brief Split a graph into multiple disjoin components.
* *
* Edges across different components are ignored. All the result graphs have the * Edges across different components are ignored. All the result graphs have the
...@@ -814,7 +814,7 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -814,7 +814,7 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes); IdArray edge_sizes);
/*! /**
* @brief Structure for pickle/unpickle. * @brief Structure for pickle/unpickle.
* *
* The design principle is to leverage the NDArray class as much as possible so * The design principle is to leverage the NDArray class as much as possible so
...@@ -827,29 +827,29 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -827,29 +827,29 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
* This class can be used as arguments and return values of a C API. * This class can be used as arguments and return values of a C API.
*/ */
struct HeteroPickleStates : public runtime::Object { struct HeteroPickleStates : public runtime::Object {
/*! @brief version number */ /** @brief version number */
int64_t version = 0; int64_t version = 0;
/*! @brief Metainformation /** @brief Metainformation
* *
* metagraph, number of nodes per type, format, flags * metagraph, number of nodes per type, format, flags
*/ */
std::string meta; std::string meta;
/*! @brief Arrays representing graph structure (coo or csr) */ /** @brief Arrays representing graph structure (coo or csr) */
std::vector<IdArray> arrays; std::vector<IdArray> arrays;
/* To support backward compatibility, we have to retain fields in the old /* To support backward compatibility, we have to retain fields in the old
* version of HeteroPickleStates * version of HeteroPickleStates
*/ */
/*! @brief Metagraph(64bits ImmutableGraph) */ /** @brief Metagraph(64bits ImmutableGraph) */
GraphPtr metagraph; GraphPtr metagraph;
/*! @brief Number of nodes per type */ /** @brief Number of nodes per type */
std::vector<int64_t> num_nodes_per_type; std::vector<int64_t> num_nodes_per_type;
/*! @brief adjacency matrices of each relation graph */ /** @brief adjacency matrices of each relation graph */
std::vector<std::shared_ptr<SparseMatrix> > adjs; std::vector<std::shared_ptr<SparseMatrix> > adjs;
static constexpr const char* _type_key = "graph.HeteroPickleStates"; static constexpr const char* _type_key = "graph.HeteroPickleStates";
...@@ -859,7 +859,7 @@ struct HeteroPickleStates : public runtime::Object { ...@@ -859,7 +859,7 @@ struct HeteroPickleStates : public runtime::Object {
// Define HeteroPickleStatesRef // Define HeteroPickleStatesRef
DGL_DEFINE_OBJECT_REF(HeteroPickleStatesRef, HeteroPickleStates); DGL_DEFINE_OBJECT_REF(HeteroPickleStatesRef, HeteroPickleStates);
/*! /**
* @brief Create a heterograph from pickling states. * @brief Create a heterograph from pickling states.
* *
* @param states Pickle states * @param states Pickle states
...@@ -867,7 +867,7 @@ DGL_DEFINE_OBJECT_REF(HeteroPickleStatesRef, HeteroPickleStates); ...@@ -867,7 +867,7 @@ DGL_DEFINE_OBJECT_REF(HeteroPickleStatesRef, HeteroPickleStates);
*/ */
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states); HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states);
/*! /**
* @brief Get the pickling state of the relation graph structure in backend * @brief Get the pickling state of the relation graph structure in backend
* tensors. * tensors.
* *
...@@ -875,7 +875,7 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states); ...@@ -875,7 +875,7 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states);
*/ */
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph); HeteroPickleStates HeteroPickle(HeteroGraphPtr graph);
/*! /**
* @brief Old version of HeteroUnpickle, for backward compatibility * @brief Old version of HeteroUnpickle, for backward compatibility
* *
* @param states Pickle states * @param states Pickle states
...@@ -883,7 +883,7 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph); ...@@ -883,7 +883,7 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph);
*/ */
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states); HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);
/*! /**
* @brief Create heterograph from pickling states pickled by ForkingPickler. * @brief Create heterograph from pickling states pickled by ForkingPickler.
* *
* This is different from HeteroUnpickle where * This is different from HeteroUnpickle where
...@@ -892,7 +892,7 @@ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states); ...@@ -892,7 +892,7 @@ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);
*/ */
HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states); HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
/*! /**
* @brief Get the pickling states of the relation graph structure in backend * @brief Get the pickling states of the relation graph structure in backend
* tensors for ForkingPickler. * tensors for ForkingPickler.
* *
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/bcast.h * @file dgl/aten/bcast.h
* @brief Broadcast related function C++ header. * @brief Broadcast related function C++ header.
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
/*! /**
* @brief Broadcast offsets and auxiliary information. * @brief Broadcast offsets and auxiliary information.
*/ */
struct BcastOff { struct BcastOff {
/*! /**
* @brief offset vector of lhs operand and rhs operand. * @brief offset vector of lhs operand and rhs operand.
* @note lhs_offset[i] indicates the start position of the scalar * @note lhs_offset[i] indicates the start position of the scalar
* in lhs operand that required to compute the i-th element * in lhs operand that required to compute the i-th element
...@@ -36,9 +36,9 @@ struct BcastOff { ...@@ -36,9 +36,9 @@ struct BcastOff {
* rhs array. * rhs array.
*/ */
std::vector<int64_t> lhs_offset, rhs_offset; std::vector<int64_t> lhs_offset, rhs_offset;
/*! @brief Whether broadcast is required or not. */ /** @brief Whether broadcast is required or not. */
bool use_bcast; bool use_bcast;
/*! /**
* @brief Auxiliary information for kernel computation * @brief Auxiliary information for kernel computation
* @note lhs_len refers to the left hand side operand length. * @note lhs_len refers to the left hand side operand length.
* e.g. 15 for shape (1, 3, 5) * e.g. 15 for shape (1, 3, 5)
...@@ -52,7 +52,7 @@ struct BcastOff { ...@@ -52,7 +52,7 @@ struct BcastOff {
int64_t lhs_len, rhs_len, out_len, reduce_size; int64_t lhs_len, rhs_len, out_len, reduce_size;
}; };
/*! /**
* @brief: Compute broadcast and auxiliary information given operator * @brief: Compute broadcast and auxiliary information given operator
* and operands for kernel computation. * and operands for kernel computation.
* @param op: a string indicates the operator, could be `add`, `sub`, * @param op: a string indicates the operator, could be `add`, `sub`,
......
/*! /**
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* @file dgl/graph.h * @file dgl/graph.h
* @brief DGL graph index class. * @brief DGL graph index class.
...@@ -21,20 +21,20 @@ class Graph; ...@@ -21,20 +21,20 @@ class Graph;
class GraphOp; class GraphOp;
typedef std::shared_ptr<Graph> MutableGraphPtr; typedef std::shared_ptr<Graph> MutableGraphPtr;
/*! @brief Mutable graph based on adjacency list. */ /** @brief Mutable graph based on adjacency list. */
class Graph : public GraphInterface { class Graph : public GraphInterface {
public: public:
/*! @brief default constructor */ /** @brief default constructor */
Graph() {} Graph() {}
/*! @brief construct a graph from the coo format. */ /** @brief construct a graph from the coo format. */
Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes); Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes);
/*! @brief default copy constructor */ /** @brief default copy constructor */
Graph(const Graph& other) = default; Graph(const Graph& other) = default;
#ifndef _MSC_VER #ifndef _MSC_VER
/*! @brief default move constructor */ /** @brief default move constructor */
Graph(Graph&& other) = default; Graph(Graph&& other) = default;
#else #else
Graph(Graph&& other) { Graph(Graph&& other) {
...@@ -48,13 +48,13 @@ class Graph : public GraphInterface { ...@@ -48,13 +48,13 @@ class Graph : public GraphInterface {
} }
#endif // _MSC_VER #endif // _MSC_VER
/*! @brief default assign constructor */ /** @brief default assign constructor */
Graph& operator=(const Graph& other) = default; Graph& operator=(const Graph& other) = default;
/*! @brief default destructor */ /** @brief default destructor */
~Graph() = default; ~Graph() = default;
/*! /**
* @brief Add vertices to the graph. * @brief Add vertices to the graph.
* @note Since vertices are integers enumerated from zero, only the number of * @note Since vertices are integers enumerated from zero, only the number of
* vertices to be added needs to be specified. * vertices to be added needs to be specified.
...@@ -62,21 +62,21 @@ class Graph : public GraphInterface { ...@@ -62,21 +62,21 @@ class Graph : public GraphInterface {
*/ */
void AddVertices(uint64_t num_vertices) override; void AddVertices(uint64_t num_vertices) override;
/*! /**
* @brief Add one edge to the graph. * @brief Add one edge to the graph.
* @param src The source vertex. * @param src The source vertex.
* @param dst The destination vertex. * @param dst The destination vertex.
*/ */
void AddEdge(dgl_id_t src, dgl_id_t dst) override; void AddEdge(dgl_id_t src, dgl_id_t dst) override;
/*! /**
* @brief Add edges to the graph. * @brief Add edges to the graph.
* @param src_ids The source vertex id array. * @param src_ids The source vertex id array.
* @param dst_ids The destination vertex id array. * @param dst_ids The destination vertex id array.
*/ */
void AddEdges(IdArray src_ids, IdArray dst_ids) override; void AddEdges(IdArray src_ids, IdArray dst_ids) override;
/*! /**
* @brief Clear the graph. Remove all vertices/edges. * @brief Clear the graph. Remove all vertices/edges.
*/ */
void Clear() override { void Clear() override {
...@@ -92,35 +92,35 @@ class Graph : public GraphInterface { ...@@ -92,35 +92,35 @@ class Graph : public GraphInterface {
uint8_t NumBits() const override { return 64; } uint8_t NumBits() const override { return 64; }
/*! /**
* @note not const since we have caches * @note not const since we have caches
* @return whether the graph is a multigraph * @return whether the graph is a multigraph
*/ */
bool IsMultigraph() const override; bool IsMultigraph() const override;
/*! /**
* @return whether the graph is read-only * @return whether the graph is read-only
*/ */
bool IsReadonly() const override { return false; } bool IsReadonly() const override { return false; }
/*! @return the number of vertices in the graph.*/ /** @return the number of vertices in the graph.*/
uint64_t NumVertices() const override { return adjlist_.size(); } uint64_t NumVertices() const override { return adjlist_.size(); }
/*! @return the number of edges in the graph.*/ /** @return the number of edges in the graph.*/
uint64_t NumEdges() const override { return num_edges_; } uint64_t NumEdges() const override { return num_edges_; }
/*! @return a 0-1 array indicating whether the given vertices are in the /** @return a 0-1 array indicating whether the given vertices are in the
* graph. * graph.
*/ */
BoolArray HasVertices(IdArray vids) const override; BoolArray HasVertices(IdArray vids) const override;
/*! @return true if the given edge is in the graph.*/ /** @return true if the given edge is in the graph.*/
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override; bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;
/*! @return a 0-1 array indicating whether the given edges are in the graph.*/ /** @return a 0-1 array indicating whether the given edges are in the graph.*/
BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override; BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override;
/*! /**
* @brief Find the predecessors of a vertex. * @brief Find the predecessors of a vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @param radius The radius of the neighborhood. Default is immediate neighbor * @param radius The radius of the neighborhood. Default is immediate neighbor
...@@ -129,7 +129,7 @@ class Graph : public GraphInterface { ...@@ -129,7 +129,7 @@ class Graph : public GraphInterface {
*/ */
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override; IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override;
/*! /**
* @brief Find the successors of a vertex. * @brief Find the successors of a vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @param radius The radius of the neighborhood. Default is immediate neighbor * @param radius The radius of the neighborhood. Default is immediate neighbor
...@@ -138,7 +138,7 @@ class Graph : public GraphInterface { ...@@ -138,7 +138,7 @@ class Graph : public GraphInterface {
*/ */
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override; IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;
/*! /**
* @brief Get all edge ids between the two given endpoints * @brief Get all edge ids between the two given endpoints
* @note Edges are associated with an integer id start from zero. * @note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph. * The id is assigned when the edge is being added to the graph.
...@@ -148,7 +148,7 @@ class Graph : public GraphInterface { ...@@ -148,7 +148,7 @@ class Graph : public GraphInterface {
*/ */
IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override; IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override;
/*! /**
* @brief Get all edge ids between the given endpoint pairs. * @brief Get all edge ids between the given endpoint pairs.
* @note Edges are associated with an integer id start from zero. * @note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph. * The id is assigned when the edge is being added to the graph.
...@@ -159,7 +159,7 @@ class Graph : public GraphInterface { ...@@ -159,7 +159,7 @@ class Graph : public GraphInterface {
*/ */
EdgeArray EdgeIds(IdArray src, IdArray dst) const override; EdgeArray EdgeIds(IdArray src, IdArray dst) const override;
/*! /**
* @brief Find the edge ID and return the pair of endpoints * @brief Find the edge ID and return the pair of endpoints
* @param eid The edge ID * @param eid The edge ID
* @return a pair whose first element is the source and the second the * @return a pair whose first element is the source and the second the
...@@ -169,7 +169,7 @@ class Graph : public GraphInterface { ...@@ -169,7 +169,7 @@ class Graph : public GraphInterface {
return std::make_pair(all_edges_src_[eid], all_edges_dst_[eid]); return std::make_pair(all_edges_src_[eid], all_edges_dst_[eid]);
} }
/*! /**
* @brief Find the edge IDs and return their source and target node IDs. * @brief Find the edge IDs and return their source and target node IDs.
* @param eids The edge ID array. * @param eids The edge ID array.
* @return EdgeArray containing all edges with id in eid. The order is * @return EdgeArray containing all edges with id in eid. The order is
...@@ -177,7 +177,7 @@ class Graph : public GraphInterface { ...@@ -177,7 +177,7 @@ class Graph : public GraphInterface {
*/ */
EdgeArray FindEdges(IdArray eids) const override; EdgeArray FindEdges(IdArray eids) const override;
/*! /**
* @brief Get the in edges of the vertex. * @brief Get the in edges of the vertex.
* @note The returned dst id array is filled with vid. * @note The returned dst id array is filled with vid.
* @param vid The vertex id. * @param vid The vertex id.
...@@ -185,14 +185,14 @@ class Graph : public GraphInterface { ...@@ -185,14 +185,14 @@ class Graph : public GraphInterface {
*/ */
EdgeArray InEdges(dgl_id_t vid) const override; EdgeArray InEdges(dgl_id_t vid) const override;
/*! /**
* @brief Get the in edges of the vertices. * @brief Get the in edges of the vertices.
* @param vids The vertex id array. * @param vids The vertex id array.
* @return the id arrays of the two endpoints of the edges. * @return the id arrays of the two endpoints of the edges.
*/ */
EdgeArray InEdges(IdArray vids) const override; EdgeArray InEdges(IdArray vids) const override;
/*! /**
* @brief Get the out edges of the vertex. * @brief Get the out edges of the vertex.
* @note The returned src id array is filled with vid. * @note The returned src id array is filled with vid.
* @param vid The vertex id. * @param vid The vertex id.
...@@ -200,14 +200,14 @@ class Graph : public GraphInterface { ...@@ -200,14 +200,14 @@ class Graph : public GraphInterface {
*/ */
EdgeArray OutEdges(dgl_id_t vid) const override; EdgeArray OutEdges(dgl_id_t vid) const override;
/*! /**
* @brief Get the out edges of the vertices. * @brief Get the out edges of the vertices.
* @param vids The vertex id array. * @param vids The vertex id array.
* @return the id arrays of the two endpoints of the edges. * @return the id arrays of the two endpoints of the edges.
*/ */
EdgeArray OutEdges(IdArray vids) const override; EdgeArray OutEdges(IdArray vids) const override;
/*! /**
* @brief Get all the edges in the graph. * @brief Get all the edges in the graph.
* @note If sorted is true, the returned edges list is sorted by their src and * @note If sorted is true, the returned edges list is sorted by their src and
* dst ids. Otherwise, they are in their edge id order. * dst ids. Otherwise, they are in their edge id order.
...@@ -217,7 +217,7 @@ class Graph : public GraphInterface { ...@@ -217,7 +217,7 @@ class Graph : public GraphInterface {
*/ */
EdgeArray Edges(const std::string& order = "") const override; EdgeArray Edges(const std::string& order = "") const override;
/*! /**
* @brief Get the in degree of the given vertex. * @brief Get the in degree of the given vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @return the in degree * @return the in degree
...@@ -227,14 +227,14 @@ class Graph : public GraphInterface { ...@@ -227,14 +227,14 @@ class Graph : public GraphInterface {
return reverse_adjlist_[vid].succ.size(); return reverse_adjlist_[vid].succ.size();
} }
/*! /**
* @brief Get the in degrees of the given vertices. * @brief Get the in degrees of the given vertices.
* @param vid The vertex id array. * @param vid The vertex id array.
* @return the in degree array * @return the in degree array
*/ */
DegreeArray InDegrees(IdArray vids) const override; DegreeArray InDegrees(IdArray vids) const override;
/*! /**
* @brief Get the out degree of the given vertex. * @brief Get the out degree of the given vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @return the out degree * @return the out degree
...@@ -244,14 +244,14 @@ class Graph : public GraphInterface { ...@@ -244,14 +244,14 @@ class Graph : public GraphInterface {
return adjlist_[vid].succ.size(); return adjlist_[vid].succ.size();
} }
/*! /**
* @brief Get the out degrees of the given vertices. * @brief Get the out degrees of the given vertices.
* @param vid The vertex id array. * @param vid The vertex id array.
* @return the out degree array * @return the out degree array
*/ */
DegreeArray OutDegrees(IdArray vids) const override; DegreeArray OutDegrees(IdArray vids) const override;
/*! /**
* @brief Construct the induced subgraph of the given vertices. * @brief Construct the induced subgraph of the given vertices.
* *
* The induced subgraph is a subgraph formed by specifying a set of vertices * The induced subgraph is a subgraph formed by specifying a set of vertices
...@@ -270,7 +270,7 @@ class Graph : public GraphInterface { ...@@ -270,7 +270,7 @@ class Graph : public GraphInterface {
*/ */
Subgraph VertexSubgraph(IdArray vids) const override; Subgraph VertexSubgraph(IdArray vids) const override;
/*! /**
* @brief Construct the induced edge subgraph of the given edges. * @brief Construct the induced edge subgraph of the given edges.
* *
* The induced edges subgraph is a subgraph formed by specifying a set of * The induced edges subgraph is a subgraph formed by specifying a set of
...@@ -290,7 +290,7 @@ class Graph : public GraphInterface { ...@@ -290,7 +290,7 @@ class Graph : public GraphInterface {
Subgraph EdgeSubgraph( Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override; IdArray eids, bool preserve_nodes = false) const override;
/*! /**
* @brief Return the successor vector * @brief Return the successor vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the successor vector * @return the successor vector
...@@ -301,7 +301,7 @@ class Graph : public GraphInterface { ...@@ -301,7 +301,7 @@ class Graph : public GraphInterface {
return DGLIdIters(data, data + size); return DGLIdIters(data, data + size);
} }
/*! /**
* @brief Return the out edge id vector * @brief Return the out edge id vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the out edge id vector * @return the out edge id vector
...@@ -312,7 +312,7 @@ class Graph : public GraphInterface { ...@@ -312,7 +312,7 @@ class Graph : public GraphInterface {
return DGLIdIters(data, data + size); return DGLIdIters(data, data + size);
} }
/*! /**
* @brief Return the predecessor vector * @brief Return the predecessor vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the predecessor vector * @return the predecessor vector
...@@ -323,7 +323,7 @@ class Graph : public GraphInterface { ...@@ -323,7 +323,7 @@ class Graph : public GraphInterface {
return DGLIdIters(data, data + size); return DGLIdIters(data, data + size);
} }
/*! /**
* @brief Return the in edge id vector * @brief Return the in edge id vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the in edge id vector * @return the in edge id vector
...@@ -334,7 +334,7 @@ class Graph : public GraphInterface { ...@@ -334,7 +334,7 @@ class Graph : public GraphInterface {
return DGLIdIters(data, data + size); return DGLIdIters(data, data + size);
} }
/*! /**
* @brief Get the adjacency matrix of the graph. * @brief Get the adjacency matrix of the graph.
* *
* By default, a row of returned adjacency matrix represents the destination * By default, a row of returned adjacency matrix represents the destination
...@@ -346,10 +346,10 @@ class Graph : public GraphInterface { ...@@ -346,10 +346,10 @@ class Graph : public GraphInterface {
std::vector<IdArray> GetAdj( std::vector<IdArray> GetAdj(
bool transpose, const std::string& fmt) const override; bool transpose, const std::string& fmt) const override;
/*! @brief Create an empty graph */ /** @brief Create an empty graph */
static MutableGraphPtr Create() { return std::make_shared<Graph>(); } static MutableGraphPtr Create() { return std::make_shared<Graph>(); }
/*! @brief Create from coo */ /** @brief Create from coo */
static MutableGraphPtr CreateFromCOO( static MutableGraphPtr CreateFromCOO(
int64_t num_nodes, IdArray src_ids, IdArray dst_ids) { int64_t num_nodes, IdArray src_ids, IdArray dst_ids) {
return std::make_shared<Graph>(src_ids, dst_ids, num_nodes); return std::make_shared<Graph>(src_ids, dst_ids, num_nodes);
...@@ -357,29 +357,29 @@ class Graph : public GraphInterface { ...@@ -357,29 +357,29 @@ class Graph : public GraphInterface {
protected: protected:
friend class GraphOp; friend class GraphOp;
/*! @brief Internal edge list type */ /** @brief Internal edge list type */
struct EdgeList { struct EdgeList {
/*! @brief successor vertex list */ /** @brief successor vertex list */
std::vector<dgl_id_t> succ; std::vector<dgl_id_t> succ;
/*! @brief out edge list */ /** @brief out edge list */
std::vector<dgl_id_t> edge_id; std::vector<dgl_id_t> edge_id;
}; };
typedef std::vector<EdgeList> AdjacencyList; typedef std::vector<EdgeList> AdjacencyList;
/*! @brief adjacency list using vector storage */ /** @brief adjacency list using vector storage */
AdjacencyList adjlist_; AdjacencyList adjlist_;
/*! @brief reverse adjacency list using vector storage */ /** @brief reverse adjacency list using vector storage */
AdjacencyList reverse_adjlist_; AdjacencyList reverse_adjlist_;
/*! @brief all edges' src endpoints in their edge id order */ /** @brief all edges' src endpoints in their edge id order */
std::vector<dgl_id_t> all_edges_src_; std::vector<dgl_id_t> all_edges_src_;
/*! @brief all edges' dst endpoints in their edge id order */ /** @brief all edges' dst endpoints in their edge id order */
std::vector<dgl_id_t> all_edges_dst_; std::vector<dgl_id_t> all_edges_dst_;
/*! @brief read only flag */ /** @brief read only flag */
bool read_only_ = false; bool read_only_ = false;
/*! @brief number of edges */ /** @brief number of edges */
uint64_t num_edges_ = 0; uint64_t num_edges_ = 0;
}; };
......
/*! /**
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* @file dgl/graph_interface.h * @file dgl/graph_interface.h
* @brief DGL graph index class. * @brief DGL graph index class.
...@@ -19,7 +19,7 @@ namespace dgl { ...@@ -19,7 +19,7 @@ namespace dgl {
const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1); const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1);
/*! /**
* @brief This class references data in std::vector. * @brief This class references data in std::vector.
* *
* This isn't a STL-style iterator. It provides a STL data container interface. * This isn't a STL-style iterator. It provides a STL data container interface.
...@@ -28,9 +28,9 @@ const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1); ...@@ -28,9 +28,9 @@ const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1);
*/ */
class DGLIdIters { class DGLIdIters {
public: public:
/* !\brief default constructor to create an empty range */ /** @brief default constructor to create an empty range */
DGLIdIters() {} DGLIdIters() {}
/* !\brief constructor with given begin and end */ /** @brief constructor with given begin and end */
DGLIdIters(const dgl_id_t *begin, const dgl_id_t *end) { DGLIdIters(const dgl_id_t *begin, const dgl_id_t *end) {
this->begin_ = begin; this->begin_ = begin;
this->end_ = end; this->end_ = end;
...@@ -44,15 +44,15 @@ class DGLIdIters { ...@@ -44,15 +44,15 @@ class DGLIdIters {
const dgl_id_t *begin_{nullptr}, *end_{nullptr}; const dgl_id_t *begin_{nullptr}, *end_{nullptr};
}; };
/*! /**
* @brief int32 version for DGLIdIters * @brief int32 version for DGLIdIters
* *
*/ */
class DGLIdIters32 { class DGLIdIters32 {
public: public:
/* !\brief default constructor to create an empty range */ /** @brief default constructor to create an empty range */
DGLIdIters32() {} DGLIdIters32() {}
/* !\brief constructor with given begin and end */ /** @brief constructor with given begin and end */
DGLIdIters32(const int32_t *begin, const int32_t *end) { DGLIdIters32(const int32_t *begin, const int32_t *end) {
this->begin_ = begin; this->begin_ = begin;
this->end_ = end; this->end_ = end;
...@@ -78,7 +78,7 @@ class GraphRef; ...@@ -78,7 +78,7 @@ class GraphRef;
class GraphInterface; class GraphInterface;
typedef std::shared_ptr<GraphInterface> GraphPtr; typedef std::shared_ptr<GraphInterface> GraphPtr;
/*! /**
* @brief dgl graph index interface. * @brief dgl graph index interface.
* *
* DGL's graph is directed. Vertices are integers enumerated from zero. * DGL's graph is directed. Vertices are integers enumerated from zero.
...@@ -93,7 +93,7 @@ class GraphInterface : public runtime::Object { ...@@ -93,7 +93,7 @@ class GraphInterface : public runtime::Object {
public: public:
virtual ~GraphInterface() = default; virtual ~GraphInterface() = default;
/*! /**
* @brief Add vertices to the graph. * @brief Add vertices to the graph.
* @note Since vertices are integers enumerated from zero, only the number of * @note Since vertices are integers enumerated from zero, only the number of
* vertices to be added needs to be specified. * vertices to be added needs to be specified.
...@@ -101,42 +101,42 @@ class GraphInterface : public runtime::Object { ...@@ -101,42 +101,42 @@ class GraphInterface : public runtime::Object {
*/ */
virtual void AddVertices(uint64_t num_vertices) = 0; virtual void AddVertices(uint64_t num_vertices) = 0;
/*! /**
* @brief Add one edge to the graph. * @brief Add one edge to the graph.
* @param src The source vertex. * @param src The source vertex.
* @param dst The destination vertex. * @param dst The destination vertex.
*/ */
virtual void AddEdge(dgl_id_t src, dgl_id_t dst) = 0; virtual void AddEdge(dgl_id_t src, dgl_id_t dst) = 0;
/*! /**
* @brief Add edges to the graph. * @brief Add edges to the graph.
* @param src_ids The source vertex id array. * @param src_ids The source vertex id array.
* @param dst_ids The destination vertex id array. * @param dst_ids The destination vertex id array.
*/ */
virtual void AddEdges(IdArray src_ids, IdArray dst_ids) = 0; virtual void AddEdges(IdArray src_ids, IdArray dst_ids) = 0;
/*! /**
* @brief Clear the graph. Remove all vertices/edges. * @brief Clear the graph. Remove all vertices/edges.
*/ */
virtual void Clear() = 0; virtual void Clear() = 0;
/*! /**
* @brief Get the device context of this graph. * @brief Get the device context of this graph.
*/ */
virtual DGLContext Context() const = 0; virtual DGLContext Context() const = 0;
/*! /**
* @brief Get the number of integer bits used to store node/edge ids * @brief Get the number of integer bits used to store node/edge ids
* (32 or 64). * (32 or 64).
*/ */
virtual uint8_t NumBits() const = 0; virtual uint8_t NumBits() const = 0;
/*! /**
* @return whether the graph is a multigraph * @return whether the graph is a multigraph
*/ */
virtual bool IsMultigraph() const = 0; virtual bool IsMultigraph() const = 0;
/*! /**
* @return whether the graph is unibipartite * @return whether the graph is unibipartite
*/ */
virtual bool IsUniBipartite() const { virtual bool IsUniBipartite() const {
...@@ -167,32 +167,32 @@ class GraphInterface : public runtime::Object { ...@@ -167,32 +167,32 @@ class GraphInterface : public runtime::Object {
return is_unibipartite; return is_unibipartite;
} }
/*! /**
* @return whether the graph is read-only * @return whether the graph is read-only
*/ */
virtual bool IsReadonly() const = 0; virtual bool IsReadonly() const = 0;
/*! @return the number of vertices in the graph.*/ /** @return the number of vertices in the graph.*/
virtual uint64_t NumVertices() const = 0; virtual uint64_t NumVertices() const = 0;
/*! @return the number of edges in the graph.*/ /** @return the number of edges in the graph.*/
virtual uint64_t NumEdges() const = 0; virtual uint64_t NumEdges() const = 0;
/*! @return true if the given vertex is in the graph.*/ /** @return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_id_t vid) const { return vid < NumVertices(); } virtual bool HasVertex(dgl_id_t vid) const { return vid < NumVertices(); }
/*! @return a 0-1 array indicating whether the given vertices are in the /** @return a 0-1 array indicating whether the given vertices are in the
* graph. * graph.
*/ */
virtual BoolArray HasVertices(IdArray vids) const = 0; virtual BoolArray HasVertices(IdArray vids) const = 0;
/*! @return true if the given edge is in the graph.*/ /** @return true if the given edge is in the graph.*/
virtual bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const = 0; virtual bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const = 0;
/*! @return a 0-1 array indicating whether the given edges are in the graph.*/ /** @return a 0-1 array indicating whether the given edges are in the graph.*/
virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const = 0; virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const = 0;
/*! /**
* @brief Find the predecessors of a vertex. * @brief Find the predecessors of a vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @param radius The radius of the neighborhood. Default is immediate neighbor * @param radius The radius of the neighborhood. Default is immediate neighbor
...@@ -201,7 +201,7 @@ class GraphInterface : public runtime::Object { ...@@ -201,7 +201,7 @@ class GraphInterface : public runtime::Object {
*/ */
virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0; virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0;
/*! /**
* @brief Find the successors of a vertex. * @brief Find the successors of a vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @param radius The radius of the neighborhood. Default is immediate neighbor * @param radius The radius of the neighborhood. Default is immediate neighbor
...@@ -210,7 +210,7 @@ class GraphInterface : public runtime::Object { ...@@ -210,7 +210,7 @@ class GraphInterface : public runtime::Object {
*/ */
virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0; virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0;
/*! /**
* @brief Get all edge ids between the two given endpoints * @brief Get all edge ids between the two given endpoints
* @note Edges are associated with an integer id start from zero. * @note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph. * The id is assigned when the edge is being added to the graph.
...@@ -220,7 +220,7 @@ class GraphInterface : public runtime::Object { ...@@ -220,7 +220,7 @@ class GraphInterface : public runtime::Object {
*/ */
virtual IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const = 0; virtual IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const = 0;
/*! /**
* @brief Get all edge ids between the given endpoint pairs. * @brief Get all edge ids between the given endpoint pairs.
* @note Edges are associated with an integer id start from zero. * @note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph. * The id is assigned when the edge is being added to the graph.
...@@ -231,7 +231,7 @@ class GraphInterface : public runtime::Object { ...@@ -231,7 +231,7 @@ class GraphInterface : public runtime::Object {
*/ */
virtual EdgeArray EdgeIds(IdArray src, IdArray dst) const = 0; virtual EdgeArray EdgeIds(IdArray src, IdArray dst) const = 0;
/*! /**
* @brief Find the edge ID and return the pair of endpoints * @brief Find the edge ID and return the pair of endpoints
* @param eid The edge ID * @param eid The edge ID
* @return a pair whose first element is the source and the second the * @return a pair whose first element is the source and the second the
...@@ -239,7 +239,7 @@ class GraphInterface : public runtime::Object { ...@@ -239,7 +239,7 @@ class GraphInterface : public runtime::Object {
*/ */
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0; virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0;
/*! /**
* @brief Find the edge IDs and return their source and target node IDs. * @brief Find the edge IDs and return their source and target node IDs.
* @param eids The edge ID array. * @param eids The edge ID array.
* @return EdgeArray containing all edges with id in eid. The order is * @return EdgeArray containing all edges with id in eid. The order is
...@@ -247,7 +247,7 @@ class GraphInterface : public runtime::Object { ...@@ -247,7 +247,7 @@ class GraphInterface : public runtime::Object {
*/ */
virtual EdgeArray FindEdges(IdArray eids) const = 0; virtual EdgeArray FindEdges(IdArray eids) const = 0;
/*! /**
* @brief Get the in edges of the vertex. * @brief Get the in edges of the vertex.
* @note The returned dst id array is filled with vid. * @note The returned dst id array is filled with vid.
* @param vid The vertex id. * @param vid The vertex id.
...@@ -255,14 +255,14 @@ class GraphInterface : public runtime::Object { ...@@ -255,14 +255,14 @@ class GraphInterface : public runtime::Object {
*/ */
virtual EdgeArray InEdges(dgl_id_t vid) const = 0; virtual EdgeArray InEdges(dgl_id_t vid) const = 0;
/*! /**
* @brief Get the in edges of the vertices. * @brief Get the in edges of the vertices.
* @param vids The vertex id array. * @param vids The vertex id array.
* @return the id arrays of the two endpoints of the edges. * @return the id arrays of the two endpoints of the edges.
*/ */
virtual EdgeArray InEdges(IdArray vids) const = 0; virtual EdgeArray InEdges(IdArray vids) const = 0;
/*! /**
* @brief Get the out edges of the vertex. * @brief Get the out edges of the vertex.
* @note The returned src id array is filled with vid. * @note The returned src id array is filled with vid.
* @param vid The vertex id. * @param vid The vertex id.
...@@ -270,14 +270,14 @@ class GraphInterface : public runtime::Object { ...@@ -270,14 +270,14 @@ class GraphInterface : public runtime::Object {
*/ */
virtual EdgeArray OutEdges(dgl_id_t vid) const = 0; virtual EdgeArray OutEdges(dgl_id_t vid) const = 0;
/*! /**
* @brief Get the out edges of the vertices. * @brief Get the out edges of the vertices.
* @param vids The vertex id array. * @param vids The vertex id array.
* @return the id arrays of the two endpoints of the edges. * @return the id arrays of the two endpoints of the edges.
*/ */
virtual EdgeArray OutEdges(IdArray vids) const = 0; virtual EdgeArray OutEdges(IdArray vids) const = 0;
/*! /**
* @brief Get all the edges in the graph. * @brief Get all the edges in the graph.
* @note If order is "srcdst", the returned edges list is sorted by their src * @note If order is "srcdst", the returned edges list is sorted by their src
* and dst ids. If order is "eid", they are in their edge id order. * and dst ids. If order is "eid", they are in their edge id order.
...@@ -287,35 +287,35 @@ class GraphInterface : public runtime::Object { ...@@ -287,35 +287,35 @@ class GraphInterface : public runtime::Object {
*/ */
virtual EdgeArray Edges(const std::string &order = "") const = 0; virtual EdgeArray Edges(const std::string &order = "") const = 0;
/*! /**
* @brief Get the in degree of the given vertex. * @brief Get the in degree of the given vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @return the in degree * @return the in degree
*/ */
virtual uint64_t InDegree(dgl_id_t vid) const = 0; virtual uint64_t InDegree(dgl_id_t vid) const = 0;
/*! /**
* @brief Get the in degrees of the given vertices. * @brief Get the in degrees of the given vertices.
* @param vid The vertex id array. * @param vid The vertex id array.
* @return the in degree array * @return the in degree array
*/ */
virtual DegreeArray InDegrees(IdArray vids) const = 0; virtual DegreeArray InDegrees(IdArray vids) const = 0;
/*! /**
* @brief Get the out degree of the given vertex. * @brief Get the out degree of the given vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @return the out degree * @return the out degree
*/ */
virtual uint64_t OutDegree(dgl_id_t vid) const = 0; virtual uint64_t OutDegree(dgl_id_t vid) const = 0;
/*! /**
* @brief Get the out degrees of the given vertices. * @brief Get the out degrees of the given vertices.
* @param vid The vertex id array. * @param vid The vertex id array.
* @return the out degree array * @return the out degree array
*/ */
virtual DegreeArray OutDegrees(IdArray vids) const = 0; virtual DegreeArray OutDegrees(IdArray vids) const = 0;
/*! /**
* @brief Construct the induced subgraph of the given vertices. * @brief Construct the induced subgraph of the given vertices.
* *
* The induced subgraph is a subgraph formed by specifying a set of vertices * The induced subgraph is a subgraph formed by specifying a set of vertices
...@@ -334,7 +334,7 @@ class GraphInterface : public runtime::Object { ...@@ -334,7 +334,7 @@ class GraphInterface : public runtime::Object {
*/ */
virtual Subgraph VertexSubgraph(IdArray vids) const = 0; virtual Subgraph VertexSubgraph(IdArray vids) const = 0;
/*! /**
* @brief Construct the induced edge subgraph of the given edges. * @brief Construct the induced edge subgraph of the given edges.
* *
* The induced edges subgraph is a subgraph formed by specifying a set of * The induced edges subgraph is a subgraph formed by specifying a set of
...@@ -356,35 +356,35 @@ class GraphInterface : public runtime::Object { ...@@ -356,35 +356,35 @@ class GraphInterface : public runtime::Object {
virtual Subgraph EdgeSubgraph( virtual Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const = 0; IdArray eids, bool preserve_nodes = false) const = 0;
/*! /**
* @brief Return the successor vector * @brief Return the successor vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the successor vector iterator pair. * @return the successor vector iterator pair.
*/ */
virtual DGLIdIters SuccVec(dgl_id_t vid) const = 0; virtual DGLIdIters SuccVec(dgl_id_t vid) const = 0;
/*! /**
* @brief Return the out edge id vector * @brief Return the out edge id vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the out edge id vector iterator pair. * @return the out edge id vector iterator pair.
*/ */
virtual DGLIdIters OutEdgeVec(dgl_id_t vid) const = 0; virtual DGLIdIters OutEdgeVec(dgl_id_t vid) const = 0;
/*! /**
* @brief Return the predecessor vector * @brief Return the predecessor vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the predecessor vector iterator pair. * @return the predecessor vector iterator pair.
*/ */
virtual DGLIdIters PredVec(dgl_id_t vid) const = 0; virtual DGLIdIters PredVec(dgl_id_t vid) const = 0;
/*! /**
* @brief Return the in edge id vector * @brief Return the in edge id vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the in edge id vector iterator pair. * @return the in edge id vector iterator pair.
*/ */
virtual DGLIdIters InEdgeVec(dgl_id_t vid) const = 0; virtual DGLIdIters InEdgeVec(dgl_id_t vid) const = 0;
/*! /**
* @brief Get the adjacency matrix of the graph. * @brief Get the adjacency matrix of the graph.
* *
* By default, a row of returned adjacency matrix represents the destination * By default, a row of returned adjacency matrix represents the destination
...@@ -403,7 +403,7 @@ class GraphInterface : public runtime::Object { ...@@ -403,7 +403,7 @@ class GraphInterface : public runtime::Object {
virtual std::vector<IdArray> GetAdj( virtual std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const = 0; bool transpose, const std::string &fmt) const = 0;
/*! /**
* @brief Sort the columns in CSR. * @brief Sort the columns in CSR.
* *
* This sorts the columns in each row based on the column Ids. * This sorts the columns in each row based on the column Ids.
...@@ -418,17 +418,17 @@ class GraphInterface : public runtime::Object { ...@@ -418,17 +418,17 @@ class GraphInterface : public runtime::Object {
// Define GraphRef // Define GraphRef
DGL_DEFINE_OBJECT_REF(GraphRef, GraphInterface); DGL_DEFINE_OBJECT_REF(GraphRef, GraphInterface);
/*! @brief Subgraph data structure */ /** @brief Subgraph data structure */
struct Subgraph : public runtime::Object { struct Subgraph : public runtime::Object {
/*! @brief The graph. */ /** @brief The graph. */
GraphPtr graph; GraphPtr graph;
/*! /**
* @brief The induced vertex ids. * @brief The induced vertex ids.
* @note This is also a map from the new vertex id to the vertex id in the * @note This is also a map from the new vertex id to the vertex id in the
* parent graph. * parent graph.
*/ */
IdArray induced_vertices; IdArray induced_vertices;
/*! /**
* @brief The induced edge ids. * @brief The induced edge ids.
* @note This is also a map from the new edge id to the edge id in the parent * @note This is also a map from the new edge id to the edge id in the parent
* graph. * graph.
...@@ -439,21 +439,21 @@ struct Subgraph : public runtime::Object { ...@@ -439,21 +439,21 @@ struct Subgraph : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object);
}; };
/*! @brief Subgraph data structure for negative subgraph */ /** @brief Subgraph data structure for negative subgraph */
struct NegSubgraph : public Subgraph { struct NegSubgraph : public Subgraph {
/*! @brief The existence of the negative edges in the parent graph. */ /** @brief The existence of the negative edges in the parent graph. */
IdArray exist; IdArray exist;
/*! @brief The Ids of head nodes */ /** @brief The Ids of head nodes */
IdArray head_nid; IdArray head_nid;
/*! @brief The Ids of tail nodes */ /** @brief The Ids of tail nodes */
IdArray tail_nid; IdArray tail_nid;
}; };
/*! @brief Subgraph data structure for halo subgraph */ /** @brief Subgraph data structure for halo subgraph */
struct HaloSubgraph : public Subgraph { struct HaloSubgraph : public Subgraph {
/*! @brief Indicate if a node belongs to the partition. */ /** @brief Indicate if a node belongs to the partition. */
IdArray inner_nodes; IdArray inner_nodes;
}; };
......
/*! /**
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* @file dgl/graph_op.h * @file dgl/graph_op.h
* @brief Operations on graph index. * @brief Operations on graph index.
...@@ -15,7 +15,7 @@ namespace dgl { ...@@ -15,7 +15,7 @@ namespace dgl {
class GraphOp { class GraphOp {
public: public:
/*! /**
* @brief Return a new graph with all the edges reversed. * @brief Return a new graph with all the edges reversed.
* *
* The returned graph preserves the vertex and edge index in the original * The returned graph preserves the vertex and edge index in the original
...@@ -25,7 +25,7 @@ class GraphOp { ...@@ -25,7 +25,7 @@ class GraphOp {
*/ */
static GraphPtr Reverse(GraphPtr graph); static GraphPtr Reverse(GraphPtr graph);
/*! /**
* @brief Return the line graph. * @brief Return the line graph.
* *
* If i~j and j~i are two edges in original graph G, then * If i~j and j~i are two edges in original graph G, then
...@@ -38,7 +38,7 @@ class GraphOp { ...@@ -38,7 +38,7 @@ class GraphOp {
*/ */
static GraphPtr LineGraph(GraphPtr graph, bool backtracking); static GraphPtr LineGraph(GraphPtr graph, bool backtracking);
/*! /**
* @brief Return a disjoint union of the input graphs. * @brief Return a disjoint union of the input graphs.
* *
* The new graph will include all the nodes/edges in the given graphs. * The new graph will include all the nodes/edges in the given graphs.
...@@ -55,7 +55,7 @@ class GraphOp { ...@@ -55,7 +55,7 @@ class GraphOp {
*/ */
static GraphPtr DisjointUnion(std::vector<GraphPtr> graphs); static GraphPtr DisjointUnion(std::vector<GraphPtr> graphs);
/*! /**
* @brief Partition the graph into several subgraphs. * @brief Partition the graph into several subgraphs.
* *
* This is a reverse operation of DisjointUnion. The graph will be partitioned * This is a reverse operation of DisjointUnion. The graph will be partitioned
...@@ -72,7 +72,7 @@ class GraphOp { ...@@ -72,7 +72,7 @@ class GraphOp {
static std::vector<GraphPtr> DisjointPartitionByNum( static std::vector<GraphPtr> DisjointPartitionByNum(
GraphPtr graph, int64_t num); GraphPtr graph, int64_t num);
/*! /**
* @brief Partition the graph into several subgraphs. * @brief Partition the graph into several subgraphs.
* *
* This is a reverse operation of DisjointUnion. The graph will be partitioned * This is a reverse operation of DisjointUnion. The graph will be partitioned
...@@ -89,7 +89,7 @@ class GraphOp { ...@@ -89,7 +89,7 @@ class GraphOp {
static std::vector<GraphPtr> DisjointPartitionBySizes( static std::vector<GraphPtr> DisjointPartitionBySizes(
GraphPtr graph, IdArray sizes); GraphPtr graph, IdArray sizes);
/*! /**
* @brief Map vids in the parent graph to the vids in the subgraph. * @brief Map vids in the parent graph to the vids in the subgraph.
* *
* If the Id doesn't exist in the subgraph, -1 will be used. * If the Id doesn't exist in the subgraph, -1 will be used.
...@@ -102,7 +102,7 @@ class GraphOp { ...@@ -102,7 +102,7 @@ class GraphOp {
*/ */
static IdArray MapParentIdToSubgraphId(IdArray parent_vid_map, IdArray query); static IdArray MapParentIdToSubgraphId(IdArray parent_vid_map, IdArray query);
/*! /**
* @brief Expand an Id array based on the offset array. * @brief Expand an Id array based on the offset array.
* *
* For example, * For example,
...@@ -118,14 +118,14 @@ class GraphOp { ...@@ -118,14 +118,14 @@ class GraphOp {
*/ */
static IdArray ExpandIds(IdArray ids, IdArray offset); static IdArray ExpandIds(IdArray ids, IdArray offset);
/*! /**
* @brief Convert the graph to a simple graph. * @brief Convert the graph to a simple graph.
* @param graph The input graph. * @param graph The input graph.
* @return a new immutable simple graph with no multi-edge. * @return a new immutable simple graph with no multi-edge.
*/ */
static GraphPtr ToSimpleGraph(GraphPtr graph); static GraphPtr ToSimpleGraph(GraphPtr graph);
/*! /**
* @brief Convert the graph to a mutable bidirected graph. * @brief Convert the graph to a mutable bidirected graph.
* *
* If the original graph has m edges for i -> j and n edges for * If the original graph has m edges for i -> j and n edges for
...@@ -137,7 +137,7 @@ class GraphOp { ...@@ -137,7 +137,7 @@ class GraphOp {
*/ */
static GraphPtr ToBidirectedMutableGraph(GraphPtr graph); static GraphPtr ToBidirectedMutableGraph(GraphPtr graph);
/*! /**
* @brief Same as BidirectedMutableGraph except that the returned graph is * @brief Same as BidirectedMutableGraph except that the returned graph is
* immutable. * immutable.
* @param graph The input graph. * @param graph The input graph.
...@@ -145,7 +145,7 @@ class GraphOp { ...@@ -145,7 +145,7 @@ class GraphOp {
* graph. * graph.
*/ */
static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph); static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph);
/*! /**
* @brief Same as BidirectedMutableGraph except that the returned graph is * @brief Same as BidirectedMutableGraph except that the returned graph is
* immutable and call gk_csr_MakeSymmetric in GKlib. This is more efficient * immutable and call gk_csr_MakeSymmetric in GKlib. This is more efficient
* than ToBidirectedImmutableGraph. It return a null pointer if the conversion * than ToBidirectedImmutableGraph. It return a null pointer if the conversion
...@@ -156,7 +156,7 @@ class GraphOp { ...@@ -156,7 +156,7 @@ class GraphOp {
*/ */
static GraphPtr ToBidirectedSimpleImmutableGraph(ImmutableGraphPtr ig); static GraphPtr ToBidirectedSimpleImmutableGraph(ImmutableGraphPtr ig);
/*! /**
* @brief Get a induced subgraph with HALO nodes. * @brief Get a induced subgraph with HALO nodes.
* The HALO nodes are the ones that can be reached from `nodes` within * The HALO nodes are the ones that can be reached from `nodes` within
* `num_hops`. * `num_hops`.
...@@ -168,7 +168,7 @@ class GraphOp { ...@@ -168,7 +168,7 @@ class GraphOp {
static HaloSubgraph GetSubgraphWithHalo( static HaloSubgraph GetSubgraphWithHalo(
GraphPtr graph, IdArray nodes, int num_hops); GraphPtr graph, IdArray nodes, int num_hops);
/*! /**
* @brief Reorder the nodes in the immutable graph. * @brief Reorder the nodes in the immutable graph.
* @param graph The input graph. * @param graph The input graph.
* @param new_order The node Ids in the new graph. The index in `new_order` is * @param new_order The node Ids in the new graph. The index in `new_order` is
......
/*! /**
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* @file graph/graph_serializer.cc * @file graph/graph_serializer.cc
* @brief DGL serializer APIs * @brief DGL serializer APIs
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/graph_traversal.h * @file dgl/graph_traversal.h
* @brief common graph traversal operations * @brief common graph traversal operations
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
namespace dgl { namespace dgl {
///////////////////////// Graph Traverse routines ////////////////////////// ///////////////////////// Graph Traverse routines //////////////////////////
/*! /**
* @brief Class for representing frontiers. * @brief Class for representing frontiers.
* *
* Each frontier is a list of nodes/edges (specified by their ids). * Each frontier is a list of nodes/edges (specified by their ids).
...@@ -20,22 +20,22 @@ namespace dgl { ...@@ -20,22 +20,22 @@ namespace dgl {
* value). * value).
*/ */
struct Frontiers { struct Frontiers {
/*!\brief a vector store for the nodes/edges in all the frontiers */ /** @brief a vector store for the nodes/edges in all the frontiers */
IdArray ids; IdArray ids;
/*! /**
* @brief a vector store for node/edge tags. Dtype is int64. * @brief a vector store for node/edge tags. Dtype is int64.
* Empty if no tags are requested * Empty if no tags are requested
*/ */
IdArray tags; IdArray tags;
/*!\brief a section vector to indicate each frontier Dtype is int64. */ /** @brief a section vector to indicate each frontier Dtype is int64. */
IdArray sections; IdArray sections;
}; };
namespace aten { namespace aten {
/*! /**
* @brief Traverse the graph in a breadth-first-search (BFS) order. * @brief Traverse the graph in a breadth-first-search (BFS) order.
* *
* @param csr The input csr matrix. * @param csr The input csr matrix.
...@@ -44,7 +44,7 @@ namespace aten { ...@@ -44,7 +44,7 @@ namespace aten {
*/ */
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source); Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
/*! /**
* @brief Traverse the graph in a breadth-first-search (BFS) order, returning * @brief Traverse the graph in a breadth-first-search (BFS) order, returning
* the edges of the BFS tree. * the edges of the BFS tree.
* *
...@@ -54,7 +54,7 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source); ...@@ -54,7 +54,7 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
*/ */
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source); Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
/*! /**
* @brief Traverse the graph in topological order. * @brief Traverse the graph in topological order.
* *
* @param csr The input csr matrix. * @param csr The input csr matrix.
...@@ -62,7 +62,7 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source); ...@@ -62,7 +62,7 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
*/ */
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr); Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
/*! /**
* @brief Traverse the graph in a depth-first-search (DFS) order. * @brief Traverse the graph in a depth-first-search (DFS) order.
* *
* @param csr The input csr matrix. * @param csr The input csr matrix.
...@@ -71,7 +71,7 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr); ...@@ -71,7 +71,7 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
*/ */
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source); Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
/*! /**
* @brief Traverse the graph in a depth-first-search (DFS) order and return the * @brief Traverse the graph in a depth-first-search (DFS) order and return the
* recorded edge tag if return_labels is specified. * recorded edge tag if return_labels is specified.
* *
......
/*! /**
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* @file dgl/immutable_graph.h * @file dgl/immutable_graph.h
* @brief DGL immutable graph index class. * @brief DGL immutable graph index class.
...@@ -29,7 +29,7 @@ typedef std::shared_ptr<COO> COOPtr; ...@@ -29,7 +29,7 @@ typedef std::shared_ptr<COO> COOPtr;
class ImmutableGraph; class ImmutableGraph;
typedef std::shared_ptr<ImmutableGraph> ImmutableGraphPtr; typedef std::shared_ptr<ImmutableGraph> ImmutableGraphPtr;
/*! /**
* @brief Graph class stored using CSR structure. * @brief Graph class stored using CSR structure.
*/ */
class CSR : public GraphInterface { class CSR : public GraphInterface {
...@@ -180,37 +180,37 @@ class CSR : public GraphInterface { ...@@ -180,37 +180,37 @@ class CSR : public GraphInterface {
return {adj_.indptr, adj_.indices, adj_.data}; return {adj_.indptr, adj_.indices, adj_.data};
} }
/*! @brief Indicate whether this uses shared memory. */ /** @brief Indicate whether this uses shared memory. */
bool IsSharedMem() const { return !shared_mem_name_.empty(); } bool IsSharedMem() const { return !shared_mem_name_.empty(); }
/*! @brief Return the reverse of this CSR graph (i.e, a CSC graph) */ /** @brief Return the reverse of this CSR graph (i.e, a CSC graph) */
CSRPtr Transpose() const; CSRPtr Transpose() const;
/*! @brief Convert this CSR to COO */ /** @brief Convert this CSR to COO */
COOPtr ToCOO() const; COOPtr ToCOO() const;
/*! /**
* @return the csr matrix that represents this graph. * @return the csr matrix that represents this graph.
* @note The csr matrix shares the storage with this graph. * @note The csr matrix shares the storage with this graph.
* The data field of the CSR matrix stores the edge ids. * The data field of the CSR matrix stores the edge ids.
*/ */
aten::CSRMatrix ToCSRMatrix() const { return adj_; } aten::CSRMatrix ToCSRMatrix() const { return adj_; }
/*! /**
* @brief Copy the data to another context. * @brief Copy the data to another context.
* @param ctx The target context. * @param ctx The target context.
* @return The graph under another context. * @return The graph under another context.
*/ */
CSR CopyTo(const DGLContext &ctx) const; CSR CopyTo(const DGLContext &ctx) const;
/*! /**
* @brief Copy data to shared memory. * @brief Copy data to shared memory.
* @param name The name of the shared memory. * @param name The name of the shared memory.
* @return The graph in the shared memory * @return The graph in the shared memory
*/ */
CSR CopyToSharedMem(const std::string &name) const; CSR CopyToSharedMem(const std::string &name) const;
/*! /**
* @brief Convert the graph to use the given number of bits for storage. * @brief Convert the graph to use the given number of bits for storage.
* @param bits The new number of integer bits (32 or 64). * @param bits The new number of integer bits (32 or 64).
* @return The graph with new bit size storage. * @return The graph with new bit size storage.
...@@ -225,10 +225,10 @@ class CSR : public GraphInterface { ...@@ -225,10 +225,10 @@ class CSR : public GraphInterface {
IdArray edge_ids() const { return adj_.data; } IdArray edge_ids() const { return adj_.data; }
/*! @return Load CSR from stream */ /** @return Load CSR from stream */
bool Load(dmlc::Stream *fs); bool Load(dmlc::Stream *fs);
/*! @return Save CSR to stream */ /** @return Save CSR to stream */
void Save(dmlc::Stream *fs) const; void Save(dmlc::Stream *fs) const;
void SortCSR() override { void SortCSR() override {
...@@ -239,7 +239,7 @@ class CSR : public GraphInterface { ...@@ -239,7 +239,7 @@ class CSR : public GraphInterface {
private: private:
friend class Serializer; friend class Serializer;
/*! @brief private default constructor */ /** @brief private default constructor */
CSR() { adj_.sorted = false; } CSR() { adj_.sorted = false; }
// The internal CSR adjacency matrix. // The internal CSR adjacency matrix.
// The data field stores edge ids. // The data field stores edge ids.
...@@ -424,43 +424,43 @@ class COO : public GraphInterface { ...@@ -424,43 +424,43 @@ class COO : public GraphInterface {
} }
} }
/*! @brief Return the transpose of this COO */ /** @brief Return the transpose of this COO */
COOPtr Transpose() const { COOPtr Transpose() const {
return COOPtr(new COO(adj_.num_rows, adj_.col, adj_.row)); return COOPtr(new COO(adj_.num_rows, adj_.col, adj_.row));
} }
/*! @brief Convert this COO to CSR */ /** @brief Convert this COO to CSR */
CSRPtr ToCSR() const; CSRPtr ToCSR() const;
/*! /**
* @brief Get the coo matrix that represents this graph. * @brief Get the coo matrix that represents this graph.
* @note The coo matrix shares the storage with this graph. * @note The coo matrix shares the storage with this graph.
* The data field of the coo matrix is none. * The data field of the coo matrix is none.
*/ */
aten::COOMatrix ToCOOMatrix() const { return adj_; } aten::COOMatrix ToCOOMatrix() const { return adj_; }
/*! /**
* @brief Copy the data to another context. * @brief Copy the data to another context.
* @param ctx The target context. * @param ctx The target context.
* @return The graph under another context. * @return The graph under another context.
*/ */
COO CopyTo(const DGLContext &ctx) const; COO CopyTo(const DGLContext &ctx) const;
/*! /**
* @brief Copy data to shared memory. * @brief Copy data to shared memory.
* @param name The name of the shared memory. * @param name The name of the shared memory.
* @return The graph in the shared memory * @return The graph in the shared memory
*/ */
COO CopyToSharedMem(const std::string &name) const; COO CopyToSharedMem(const std::string &name) const;
/*! /**
* @brief Convert the graph to use the given number of bits for storage. * @brief Convert the graph to use the given number of bits for storage.
* @param bits The new number of integer bits (32 or 64). * @param bits The new number of integer bits (32 or 64).
* @return The graph with new bit size storage. * @return The graph with new bit size storage.
*/ */
COO AsNumBits(uint8_t bits) const; COO AsNumBits(uint8_t bits) const;
/*! @brief Indicate whether this uses shared memory. */ /** @brief Indicate whether this uses shared memory. */
bool IsSharedMem() const { return false; } bool IsSharedMem() const { return false; }
// member getters // member getters
...@@ -470,7 +470,7 @@ class COO : public GraphInterface { ...@@ -470,7 +470,7 @@ class COO : public GraphInterface {
IdArray dst() const { return adj_.col; } IdArray dst() const { return adj_.col; }
private: private:
/* !\brief private default constructor */ /** @brief private default constructor */
COO() {} COO() {}
// The internal COO adjacency matrix. // The internal COO adjacency matrix.
...@@ -478,17 +478,17 @@ class COO : public GraphInterface { ...@@ -478,17 +478,17 @@ class COO : public GraphInterface {
aten::COOMatrix adj_; aten::COOMatrix adj_;
}; };
/*! /**
* @brief DGL immutable graph index class. * @brief DGL immutable graph index class.
* *
* DGL's graph is directed. Vertices are integers enumerated from zero. * DGL's graph is directed. Vertices are integers enumerated from zero.
*/ */
class ImmutableGraph : public GraphInterface { class ImmutableGraph : public GraphInterface {
public: public:
/*! @brief Construct an immutable graph from the COO format. */ /** @brief Construct an immutable graph from the COO format. */
explicit ImmutableGraph(COOPtr coo) : coo_(coo) {} explicit ImmutableGraph(COOPtr coo) : coo_(coo) {}
/*! /**
* @brief Construct an immutable graph from the CSR format. * @brief Construct an immutable graph from the CSR format.
* *
* For a single graph, we need two CSRs, one stores the in-edges of vertices * For a single graph, we need two CSRs, one stores the in-edges of vertices
...@@ -506,14 +506,14 @@ class ImmutableGraph : public GraphInterface { ...@@ -506,14 +506,14 @@ class ImmutableGraph : public GraphInterface {
CHECK(in_csr_ || out_csr_) << "Both CSR are missing."; CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
} }
/*! @brief Construct an immutable graph from one CSR. */ /** @brief Construct an immutable graph from one CSR. */
explicit ImmutableGraph(CSRPtr csr) : out_csr_(csr) {} explicit ImmutableGraph(CSRPtr csr) : out_csr_(csr) {}
/*! @brief default copy constructor */ /** @brief default copy constructor */
ImmutableGraph(const ImmutableGraph &other) = default; ImmutableGraph(const ImmutableGraph &other) = default;
#ifndef _MSC_VER #ifndef _MSC_VER
/*! @brief default move constructor */ /** @brief default move constructor */
ImmutableGraph(ImmutableGraph &&other) = default; ImmutableGraph(ImmutableGraph &&other) = default;
#else #else
ImmutableGraph(ImmutableGraph &&other) { ImmutableGraph(ImmutableGraph &&other) {
...@@ -526,10 +526,10 @@ class ImmutableGraph : public GraphInterface { ...@@ -526,10 +526,10 @@ class ImmutableGraph : public GraphInterface {
} }
#endif // _MSC_VER #endif // _MSC_VER
/*! @brief default assign constructor */ /** @brief default assign constructor */
ImmutableGraph &operator=(const ImmutableGraph &other) = default; ImmutableGraph &operator=(const ImmutableGraph &other) = default;
/*! @brief default destructor */ /** @brief default destructor */
~ImmutableGraph() = default; ~ImmutableGraph() = default;
void AddVertices(uint64_t num_vertices) override { void AddVertices(uint64_t num_vertices) override {
...@@ -552,13 +552,13 @@ class ImmutableGraph : public GraphInterface { ...@@ -552,13 +552,13 @@ class ImmutableGraph : public GraphInterface {
uint8_t NumBits() const override { return AnyGraph()->NumBits(); } uint8_t NumBits() const override { return AnyGraph()->NumBits(); }
/*! /**
* @note not const since we have caches * @note not const since we have caches
* @return whether the graph is a multigraph * @return whether the graph is a multigraph
*/ */
bool IsMultigraph() const override { return AnyGraph()->IsMultigraph(); } bool IsMultigraph() const override { return AnyGraph()->IsMultigraph(); }
/*! /**
* @return whether the graph is read-only * @return whether the graph is read-only
*/ */
bool IsReadonly() const override { return true; } bool IsReadonly() const override { return true; }
...@@ -577,18 +577,18 @@ class ImmutableGraph : public GraphInterface { ...@@ -577,18 +577,18 @@ class ImmutableGraph : public GraphInterface {
return is_unibipartite_; return is_unibipartite_;
} }
/*! @return the number of vertices in the graph.*/ /** @return the number of vertices in the graph.*/
uint64_t NumVertices() const override { return AnyGraph()->NumVertices(); } uint64_t NumVertices() const override { return AnyGraph()->NumVertices(); }
/*! @return the number of edges in the graph.*/ /** @return the number of edges in the graph.*/
uint64_t NumEdges() const override { return AnyGraph()->NumEdges(); } uint64_t NumEdges() const override { return AnyGraph()->NumEdges(); }
/*! @return true if the given vertex is in the graph.*/ /** @return true if the given vertex is in the graph.*/
bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); } bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); }
BoolArray HasVertices(IdArray vids) const override; BoolArray HasVertices(IdArray vids) const override;
/*! @return true if the given edge is in the graph.*/ /** @return true if the given edge is in the graph.*/
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override { bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
if (in_csr_) { if (in_csr_) {
return in_csr_->HasEdgeBetween(dst, src); return in_csr_->HasEdgeBetween(dst, src);
...@@ -605,7 +605,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -605,7 +605,7 @@ class ImmutableGraph : public GraphInterface {
} }
} }
/*! /**
* @brief Find the predecessors of a vertex. * @brief Find the predecessors of a vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @param radius The radius of the neighborhood. Default is immediate neighbor * @param radius The radius of the neighborhood. Default is immediate neighbor
...@@ -616,7 +616,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -616,7 +616,7 @@ class ImmutableGraph : public GraphInterface {
return GetInCSR()->Successors(vid, radius); return GetInCSR()->Successors(vid, radius);
} }
/*! /**
* @brief Find the successors of a vertex. * @brief Find the successors of a vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @param radius The radius of the neighborhood. Default is immediate neighbor * @param radius The radius of the neighborhood. Default is immediate neighbor
...@@ -627,7 +627,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -627,7 +627,7 @@ class ImmutableGraph : public GraphInterface {
return GetOutCSR()->Successors(vid, radius); return GetOutCSR()->Successors(vid, radius);
} }
/*! /**
* @brief Get all edge ids between the two given endpoints * @brief Get all edge ids between the two given endpoints
* @note Edges are associated with an integer id start from zero. * @note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph. * The id is assigned when the edge is being added to the graph.
...@@ -643,7 +643,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -643,7 +643,7 @@ class ImmutableGraph : public GraphInterface {
} }
} }
/*! /**
* @brief Get all edge ids between the given endpoint pairs. * @brief Get all edge ids between the given endpoint pairs.
* @note Edges are associated with an integer id start from zero. * @note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph. * The id is assigned when the edge is being added to the graph.
...@@ -661,7 +661,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -661,7 +661,7 @@ class ImmutableGraph : public GraphInterface {
} }
} }
/*! /**
* @brief Find the edge ID and return the pair of endpoints * @brief Find the edge ID and return the pair of endpoints
* @param eid The edge ID * @param eid The edge ID
* @return a pair whose first element is the source and the second the * @return a pair whose first element is the source and the second the
...@@ -671,7 +671,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -671,7 +671,7 @@ class ImmutableGraph : public GraphInterface {
return GetCOO()->FindEdge(eid); return GetCOO()->FindEdge(eid);
} }
/*! /**
* @brief Find the edge IDs and return their source and target node IDs. * @brief Find the edge IDs and return their source and target node IDs.
* @param eids The edge ID array. * @param eids The edge ID array.
* @return EdgeArray containing all edges with id in eid. The order is * @return EdgeArray containing all edges with id in eid. The order is
...@@ -681,7 +681,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -681,7 +681,7 @@ class ImmutableGraph : public GraphInterface {
return GetCOO()->FindEdges(eids); return GetCOO()->FindEdges(eids);
} }
/*! /**
* @brief Get the in edges of the vertex. * @brief Get the in edges of the vertex.
* @note The returned dst id array is filled with vid. * @note The returned dst id array is filled with vid.
* @param vid The vertex id. * @param vid The vertex id.
...@@ -692,7 +692,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -692,7 +692,7 @@ class ImmutableGraph : public GraphInterface {
return {ret.dst, ret.src, ret.id}; return {ret.dst, ret.src, ret.id};
} }
/*! /**
* @brief Get the in edges of the vertices. * @brief Get the in edges of the vertices.
* @param vids The vertex id array. * @param vids The vertex id array.
* @return the id arrays of the two endpoints of the edges. * @return the id arrays of the two endpoints of the edges.
...@@ -702,7 +702,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -702,7 +702,7 @@ class ImmutableGraph : public GraphInterface {
return {ret.dst, ret.src, ret.id}; return {ret.dst, ret.src, ret.id};
} }
/*! /**
* @brief Get the out edges of the vertex. * @brief Get the out edges of the vertex.
* @note The returned src id array is filled with vid. * @note The returned src id array is filled with vid.
* @param vid The vertex id. * @param vid The vertex id.
...@@ -712,7 +712,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -712,7 +712,7 @@ class ImmutableGraph : public GraphInterface {
return GetOutCSR()->OutEdges(vid); return GetOutCSR()->OutEdges(vid);
} }
/*! /**
* @brief Get the out edges of the vertices. * @brief Get the out edges of the vertices.
* @param vids The vertex id array. * @param vids The vertex id array.
* @return the id arrays of the two endpoints of the edges. * @return the id arrays of the two endpoints of the edges.
...@@ -721,7 +721,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -721,7 +721,7 @@ class ImmutableGraph : public GraphInterface {
return GetOutCSR()->OutEdges(vids); return GetOutCSR()->OutEdges(vids);
} }
/*! /**
* @brief Get all the edges in the graph. * @brief Get all the edges in the graph.
* @note If sorted is true, the returned edges list is sorted by their src and * @note If sorted is true, the returned edges list is sorted by their src and
* dst ids. Otherwise, they are in their edge id order. * dst ids. Otherwise, they are in their edge id order.
...@@ -731,7 +731,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -731,7 +731,7 @@ class ImmutableGraph : public GraphInterface {
*/ */
EdgeArray Edges(const std::string &order = "") const override; EdgeArray Edges(const std::string &order = "") const override;
/*! /**
* @brief Get the in degree of the given vertex. * @brief Get the in degree of the given vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @return the in degree * @return the in degree
...@@ -740,7 +740,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -740,7 +740,7 @@ class ImmutableGraph : public GraphInterface {
return GetInCSR()->OutDegree(vid); return GetInCSR()->OutDegree(vid);
} }
/*! /**
* @brief Get the in degrees of the given vertices. * @brief Get the in degrees of the given vertices.
* @param vid The vertex id array. * @param vid The vertex id array.
* @return the in degree array * @return the in degree array
...@@ -749,7 +749,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -749,7 +749,7 @@ class ImmutableGraph : public GraphInterface {
return GetInCSR()->OutDegrees(vids); return GetInCSR()->OutDegrees(vids);
} }
/*! /**
* @brief Get the out degree of the given vertex. * @brief Get the out degree of the given vertex.
* @param vid The vertex id. * @param vid The vertex id.
* @return the out degree * @return the out degree
...@@ -758,7 +758,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -758,7 +758,7 @@ class ImmutableGraph : public GraphInterface {
return GetOutCSR()->OutDegree(vid); return GetOutCSR()->OutDegree(vid);
} }
/*! /**
* @brief Get the out degrees of the given vertices. * @brief Get the out degrees of the given vertices.
* @param vid The vertex id array. * @param vid The vertex id array.
* @return the out degree array * @return the out degree array
...@@ -767,7 +767,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -767,7 +767,7 @@ class ImmutableGraph : public GraphInterface {
return GetOutCSR()->OutDegrees(vids); return GetOutCSR()->OutDegrees(vids);
} }
/*! /**
* @brief Construct the induced subgraph of the given vertices. * @brief Construct the induced subgraph of the given vertices.
* *
* The induced subgraph is a subgraph formed by specifying a set of vertices * The induced subgraph is a subgraph formed by specifying a set of vertices
...@@ -786,7 +786,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -786,7 +786,7 @@ class ImmutableGraph : public GraphInterface {
*/ */
Subgraph VertexSubgraph(IdArray vids) const override; Subgraph VertexSubgraph(IdArray vids) const override;
/*! /**
* @brief Construct the induced edge subgraph of the given edges. * @brief Construct the induced edge subgraph of the given edges.
* *
* The induced edges subgraph is a subgraph formed by specifying a set of * The induced edges subgraph is a subgraph formed by specifying a set of
...@@ -806,7 +806,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -806,7 +806,7 @@ class ImmutableGraph : public GraphInterface {
Subgraph EdgeSubgraph( Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override; IdArray eids, bool preserve_nodes = false) const override;
/*! /**
* @brief Return the successor vector * @brief Return the successor vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the successor vector * @return the successor vector
...@@ -815,7 +815,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -815,7 +815,7 @@ class ImmutableGraph : public GraphInterface {
return GetOutCSR()->SuccVec(vid); return GetOutCSR()->SuccVec(vid);
} }
/*! /**
* @brief Return the out edge id vector * @brief Return the out edge id vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the out edge id vector * @return the out edge id vector
...@@ -824,7 +824,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -824,7 +824,7 @@ class ImmutableGraph : public GraphInterface {
return GetOutCSR()->OutEdgeVec(vid); return GetOutCSR()->OutEdgeVec(vid);
} }
/*! /**
* @brief Return the predecessor vector * @brief Return the predecessor vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the predecessor vector * @return the predecessor vector
...@@ -833,7 +833,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -833,7 +833,7 @@ class ImmutableGraph : public GraphInterface {
return GetInCSR()->SuccVec(vid); return GetInCSR()->SuccVec(vid);
} }
/*! /**
* @brief Return the in edge id vector * @brief Return the in edge id vector
* @param vid The vertex id. * @param vid The vertex id.
* @return the in edge id vector * @return the in edge id vector
...@@ -842,7 +842,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -842,7 +842,7 @@ class ImmutableGraph : public GraphInterface {
return GetInCSR()->OutEdgeVec(vid); return GetInCSR()->OutEdgeVec(vid);
} }
/*! /**
* @brief Get the adjacency matrix of the graph. * @brief Get the adjacency matrix of the graph.
* *
* By default, a row of returned adjacency matrix represents the destination * By default, a row of returned adjacency matrix represents the destination
...@@ -854,28 +854,28 @@ class ImmutableGraph : public GraphInterface { ...@@ -854,28 +854,28 @@ class ImmutableGraph : public GraphInterface {
std::vector<IdArray> GetAdj( std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const override; bool transpose, const std::string &fmt) const override;
/* !\brief Return in csr. If not exist, transpose the other one.*/ /** @brief Return in csr. If not exist, transpose the other one.*/
CSRPtr GetInCSR() const; CSRPtr GetInCSR() const;
/* !\brief Return out csr. If not exist, transpose the other one.*/ /** @brief Return out csr. If not exist, transpose the other one.*/
CSRPtr GetOutCSR() const; CSRPtr GetOutCSR() const;
/* !\brief Return coo. If not exist, create from csr.*/ /** @brief Return coo. If not exist, create from csr.*/
COOPtr GetCOO() const; COOPtr GetCOO() const;
/*! @brief Create an immutable graph from CSR. */ /** @brief Create an immutable graph from CSR. */
static ImmutableGraphPtr CreateFromCSR( static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir); const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name); static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name);
/*! @brief Create an immutable graph from COO. */ /** @brief Create an immutable graph from COO. */
static ImmutableGraphPtr CreateFromCOO( static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, bool row_osrted = false, int64_t num_vertices, IdArray src, IdArray dst, bool row_osrted = false,
bool col_sorted = false); bool col_sorted = false);
/*! /**
* @brief Convert the given graph to an immutable graph. * @brief Convert the given graph to an immutable graph.
* *
* If the graph is already an immutable graph. The result graph will share * If the graph is already an immutable graph. The result graph will share
...@@ -886,14 +886,14 @@ class ImmutableGraph : public GraphInterface { ...@@ -886,14 +886,14 @@ class ImmutableGraph : public GraphInterface {
*/ */
static ImmutableGraphPtr ToImmutable(GraphPtr graph); static ImmutableGraphPtr ToImmutable(GraphPtr graph);
/*! /**
* @brief Copy the data to another context. * @brief Copy the data to another context.
* @param ctx The target context. * @param ctx The target context.
* @return The graph under another context. * @return The graph under another context.
*/ */
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext &ctx); static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext &ctx);
/*! /**
* @brief Copy data to shared memory. * @brief Copy data to shared memory.
* @param name The name of the shared memory. * @param name The name of the shared memory.
* @return The graph in the shared memory * @return The graph in the shared memory
...@@ -901,14 +901,14 @@ class ImmutableGraph : public GraphInterface { ...@@ -901,14 +901,14 @@ class ImmutableGraph : public GraphInterface {
static ImmutableGraphPtr CopyToSharedMem( static ImmutableGraphPtr CopyToSharedMem(
ImmutableGraphPtr g, const std::string &name); ImmutableGraphPtr g, const std::string &name);
/*! /**
* @brief Convert the graph to use the given number of bits for storage. * @brief Convert the graph to use the given number of bits for storage.
* @param bits The new number of integer bits (32 or 64). * @param bits The new number of integer bits (32 or 64).
* @return The graph with new bit size storage. * @return The graph with new bit size storage.
*/ */
static ImmutableGraphPtr AsNumBits(ImmutableGraphPtr g, uint8_t bits); static ImmutableGraphPtr AsNumBits(ImmutableGraphPtr g, uint8_t bits);
/*! /**
* @brief Return a new graph with all the edges reversed. * @brief Return a new graph with all the edges reversed.
* *
* The returned graph preserves the vertex and edge index in the original * The returned graph preserves the vertex and edge index in the original
...@@ -918,10 +918,10 @@ class ImmutableGraph : public GraphInterface { ...@@ -918,10 +918,10 @@ class ImmutableGraph : public GraphInterface {
*/ */
ImmutableGraphPtr Reverse() const; ImmutableGraphPtr Reverse() const;
/*! @return Load ImmutableGraph from stream, using out csr */ /** @return Load ImmutableGraph from stream, using out csr */
bool Load(dmlc::Stream *fs); bool Load(dmlc::Stream *fs);
/*! @return Save ImmutableGraph to stream, using out csr */ /** @return Save ImmutableGraph to stream, using out csr */
void Save(dmlc::Stream *fs) const; void Save(dmlc::Stream *fs) const;
void SortCSR() override { void SortCSR() override {
...@@ -933,17 +933,17 @@ class ImmutableGraph : public GraphInterface { ...@@ -933,17 +933,17 @@ class ImmutableGraph : public GraphInterface {
bool HasOutCSR() const { return out_csr_ != NULL; } bool HasOutCSR() const { return out_csr_ != NULL; }
/*! @brief Cast this graph to a heterograph */ /** @brief Cast this graph to a heterograph */
HeteroGraphPtr AsHeteroGraph() const; HeteroGraphPtr AsHeteroGraph() const;
protected: protected:
friend class Serializer; friend class Serializer;
friend class UnitGraph; friend class UnitGraph;
/* !\brief internal default constructor */ /** @brief internal default constructor */
ImmutableGraph() {} ImmutableGraph() {}
/* !\brief internal constructor for all the members */ /** @brief internal constructor for all the members */
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo) ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
: in_csr_(in_csr), out_csr_(out_csr), coo_(coo) { : in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
CHECK(AnyGraph()) << "At least one graph structure should exist."; CHECK(AnyGraph()) << "At least one graph structure should exist.";
...@@ -956,7 +956,7 @@ class ImmutableGraph : public GraphInterface { ...@@ -956,7 +956,7 @@ class ImmutableGraph : public GraphInterface {
this->shared_mem_name_ = shared_mem_name; this->shared_mem_name_ = shared_mem_name;
} }
/* !\brief return pointer to any available graph structure */ /** @brief return pointer to any available graph structure */
GraphPtr AnyGraph() const { GraphPtr AnyGraph() const {
if (in_csr_) { if (in_csr_) {
return in_csr_; return in_csr_;
......
/*! /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file dgl/aten/kernel.h * @file dgl/aten/kernel.h
* @brief Sparse matrix operators. * @brief Sparse matrix operators.
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace dgl { namespace dgl {
namespace aten { namespace aten {
/*! /**
* @brief Generalized Sparse Matrix-Matrix Multiplication. * @brief Generalized Sparse Matrix-Matrix Multiplication.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div', * @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `copy_u`, `copy_e'. * `copy_u`, `copy_e'.
...@@ -34,7 +34,7 @@ void SpMM( ...@@ -34,7 +34,7 @@ void SpMM(
const std::string& op, const std::string& reduce, HeteroGraphPtr graph, const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
/*! /**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication. * @brief Generalized Sampled Dense-Dense Matrix Multiplication.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div', * @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `dot`, `copy_u`, `copy_e'. * `dot`, `copy_u`, `copy_e'.
...@@ -47,7 +47,7 @@ void SDDMM( ...@@ -47,7 +47,7 @@ void SDDMM(
const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat, const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,
NDArray out); NDArray out);
/*! /**
* @brief Sparse-sparse matrix multiplication. * @brief Sparse-sparse matrix multiplication.
* *
* The sparse matrices must have scalar weights (i.e. \a A_weights and \a * The sparse matrices must have scalar weights (i.e. \a A_weights and \a
...@@ -56,7 +56,7 @@ void SDDMM( ...@@ -56,7 +56,7 @@ void SDDMM(
std::pair<CSRMatrix, NDArray> CSRMM( std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights); CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights);
/*! /**
* @brief Summing up a list of sparse matrices. * @brief Summing up a list of sparse matrices.
* *
* The sparse matrices must have scalar weights (i.e. the arrays in \a A_weights * The sparse matrices must have scalar weights (i.e. the arrays in \a A_weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment