"...text-generation-inference.git" did not exist on "dea9c0dc741875fde9225e6c2a51d7bb8fb052e4"
Unverified Commit 8ac27dad authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] clang-format auto fix. (#4824)



* [Misc] clang-format auto fix.

* blabla

* ablabla

* blabla
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bcd37684
@@ -10,10 +10,11 @@
 #define DGL_ATEN_ARRAY_OPS_H_

 #include <algorithm>
+#include <string>
+#include <tuple>
 #include <utility>
 #include <vector>
-#include <tuple>
-#include <string>

 #include "./types.h"

 namespace dgl {
@@ -24,17 +25,16 @@ namespace aten {
 //////////////////////////////////////////////////////////////////////

 /** @return A special array to represent null. */
-inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
-                         const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
+inline NDArray NullArray(
+    const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
+    const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
   return NDArray::Empty({0}, dtype, ctx);
 }

 /**
  * @return Whether the input array is a null array.
  */
-inline bool IsNullArray(NDArray array) {
-  return array->shape[0] == 0;
-}
+inline bool IsNullArray(NDArray array) { return array->shape[0] == 0; }
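For reference, the two helpers above compose as follows; a minimal sketch, assuming a DGL build where this header is reachable as dgl/aten/array_ops.h:

    #include <dgl/aten/array_ops.h>

    void NullArrayDemo() {
      // NullArray() defaults to an int64 CPU array of shape {0} ...
      auto arr = dgl::aten::NullArray();
      // ... and IsNullArray is just a zero-length shape check.
      bool is_null = dgl::aten::IsNullArray(arr);  // true
      (void)is_null;
    }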
 /**
  * @brief Create a new id array with given length
@@ -43,9 +43,9 @@ inline bool IsNullArray(NDArray array) {
  * @param nbits The number of integer bits
  * @return id array
  */
-IdArray NewIdArray(int64_t length,
-                   DGLContext ctx = DGLContext{kDGLCPU, 0},
-                   uint8_t nbits = 64);
+IdArray NewIdArray(
+    int64_t length, DGLContext ctx = DGLContext{kDGLCPU, 0},
+    uint8_t nbits = 64);
 /**
  * @brief Create a new id array using the given vector data
@@ -55,9 +55,9 @@ IdArray NewIdArray(int64_t length,
  * @return the id array
  */
 template <typename T>
-IdArray VecToIdArray(const std::vector<T>& vec,
-                     uint8_t nbits = 64,
-                     DGLContext ctx = DGLContext{kDGLCPU, 0});
+IdArray VecToIdArray(
+    const std::vector<T>& vec, uint8_t nbits = 64,
+    DGLContext ctx = DGLContext{kDGLCPU, 0});
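A usage sketch of the declaration above (same assumed include path as before):

    #include <vector>
    #include <dgl/aten/array_ops.h>

    void VecToIdArrayDemo() {
      std::vector<int64_t> ids = {0, 2, 5};
      // Defaults give a 64-bit id array on CPU.
      auto ids64 = dgl::aten::VecToIdArray(ids);
      // Request 32-bit storage instead; the context still defaults to CPU.
      auto ids32 = dgl::aten::VecToIdArray(ids, 32);
    }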
 /**
  * @brief Return an array representing a 1D range.
@@ -148,7 +148,7 @@ IdArray NonZero(BoolArray bool_arr);
  * @brief Return the data under the index. In numpy notation, A[I]
  * @tparam ValueType The type of return value.
  */
-template<typename ValueType>
+template <typename ValueType>
 ValueType IndexSelect(NDArray array, int64_t index);

 /**
@@ -187,10 +187,11 @@ NDArray Scatter(NDArray array, IdArray indices);
 void Scatter_(IdArray index, NDArray value, NDArray out);

 /**
- * @brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats)
+ * @brief Repeat each element a number of times. Equivalent to np.repeat(array,
+ * repeats)
  * @param array A 1D vector
- * @param repeats A 1D integer vector for number of times to repeat for each element in
- *        \c array. Must have the same shape as \c array.
+ * @param repeats A 1D integer vector for number of times to repeat for each
+ * element in \c array. Must have the same shape as \c array.
  */
 NDArray Repeat(NDArray array, IdArray repeats);
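To make the np.repeat analogy concrete, a small sketch, assuming IdArray is the usual NDArray alias from ./types.h so it is accepted where a 1D NDArray is expected:

    #include <vector>
    #include <dgl/aten/array_ops.h>

    void RepeatDemo() {
      // With array = [1, 2, 3] and repeats = [2, 0, 3] the result is
      // [1, 1, 3, 3, 3]: element i appears repeats[i] times.
      auto array = dgl::aten::VecToIdArray(std::vector<int64_t>{1, 2, 3});
      auto repeats = dgl::aten::VecToIdArray(std::vector<int64_t>{2, 0, 3});
      auto out = dgl::aten::Repeat(array, repeats);
      (void)out;
    }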
@@ -253,12 +254,13 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
  *
  * The packed tensor would be [1, 2, 3, 4, 5].
  *
- * The length tensor would be [2, 3], i.e. the length of each sequence before padding.
+ * The length tensor would be [2, 3], i.e. the length of each sequence before
+ * padding.
  *
- * The offset tensor would be [0, 2], i.e. the offset to the packed tensor for each
- * sequence (before padding)
+ * The offset tensor would be [0, 2], i.e. the offset to the packed tensor for
+ * each sequence (before padding)
  */
-template<typename ValueType>
+template <typename ValueType>
 std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
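A sketch of the contract described in the comment, building the padded input by hand (C++17 for the structured binding; the fill pattern mirrors the VecToIdArray implementation further down):

    #include <algorithm>
    #include <dgl/aten/array_ops.h>

    void PackDemo() {
      using namespace dgl;
      // The 2 x 3 padded example, pad value -1:
      //   [[1, 2, -1],
      //    [3, 4,  5]]
      auto padded = runtime::NDArray::Empty(
          {2, 3}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
      const int64_t vals[] = {1, 2, -1, 3, 4, 5};
      std::copy(vals, vals + 6, static_cast<int64_t*>(padded->data));

      auto [packed, lengths, offsets] = aten::Pack(padded, int64_t{-1});
      // packed  == [1, 2, 3, 4, 5]
      // lengths == [2, 3]
      // offsets == [0, 2]
    }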
 /**
@@ -295,8 +297,9 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
  * @brief Return the cumulative summation (or inclusive sum) of the input array.
  *
  * The first element out[0] is equal to the first element of the input array
- * array[0]. The rest elements are defined recursively, out[i] = out[i-1] + array[i].
- * Hence, the result array length is the same as the input array length.
+ * array[0]. The rest elements are defined recursively, out[i] = out[i-1] +
+ * array[i]. Hence, the result array length is the same as the input array
+ * length.
  *
  * If prepend_zero is true, then the first element is zero and the result array
  * length is the input array length plus one. This is useful for creating
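A worked example of the recurrence above. The declaration itself falls in the collapsed part of this hunk, so the signature used here is an assumption: IdArray CumSum(IdArray array, bool prepend_zero = false).

    // Assumed signature; the actual declaration is collapsed in this hunk.
    auto a = dgl::aten::VecToIdArray(std::vector<int64_t>{1, 2, 3, 4});
    auto inclusive = dgl::aten::CumSum(a);        // [1, 3, 6, 10]
    auto shifted = dgl::aten::CumSum(a, true);    // [0, 1, 3, 6, 10]
    // The prepend_zero form is exactly the indptr-style offset layout.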
@@ -320,17 +323,18 @@ IdArray NonZero(NDArray array);
 /**
  * @brief Sort the ID vector in ascending order.
  *
- * It performs both sort and arg_sort (returning the sorted index). The sorted index
- * is always in int64.
+ * It performs both sort and arg_sort (returning the sorted index). The sorted
+ * index is always in int64.
  *
  * @param array Input array.
- * @param num_bits The number of bits used in key comparison. For example, if the data type
- *        of the input array is int32_t and `num_bits = 8`, it only uses bits in index
- *        range [0, 8) for sorting. Setting it to a small value could
- *        speed up the sorting if the underlying sorting algorithm is radix sort (e.g., on GPU).
- *        Setting it to zero (default value) means using all the bits for comparison.
- *        On CPU, it currently has no effect.
- * @return A pair of arrays: sorted values and sorted index to the original position.
+ * @param num_bits The number of bits used in key comparison. For example, if
+ * the data type of the input array is int32_t and `num_bits = 8`, it only uses
+ * bits in index range [0, 8) for sorting. Setting it to a small value could
+ * speed up the sorting if the underlying sorting algorithm is
+ * radix sort (e.g., on GPU). Setting it to zero (default value) means using all
+ * the bits for comparison. On CPU, it currently has no effect.
+ * @return A pair of arrays: sorted values and sorted index to the original
+ * position.
  */
 std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0);
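A short sketch of both return values (C++17 structured bindings):

    auto ids = dgl::aten::VecToIdArray(std::vector<int64_t>{5, 1, 4, 2});
    auto [sorted, order] = dgl::aten::Sort(ids);
    // sorted -> [1, 2, 4, 5]; order (always int64) -> [1, 3, 2, 0]
    // num_bits = 8 keys the radix sort on the low 8 bits only, which is
    // safe here because every value fits in [0, 256); on CPU it is a no-op.
    auto [sorted8, order8] = dgl::aten::Sort(ids, 8);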
@@ -341,9 +345,7 @@ std::string ToDebugString(NDArray array);
 // inline implementations
 template <typename T>
-IdArray VecToIdArray(const std::vector<T>& vec,
-                     uint8_t nbits,
-                     DGLContext ctx) {
+IdArray VecToIdArray(const std::vector<T>& vec, uint8_t nbits, DGLContext ctx) {
   IdArray ret = NewIdArray(vec.size(), DGLContext{kDGLCPU, 0}, nbits);
   if (nbits == 32) {
     std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
@@ -367,7 +369,8 @@ inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) {
       first = false;
       result = array->ctx;
     } else {
-      CHECK_EQ(array->ctx, result) << "Context of the input arrays are different";
+      CHECK_EQ(array->ctx, result)
+          << "Context of the input arrays are different";
     }
   }
   return result;
......
@@ -9,14 +9,16 @@
 #include <dmlc/io.h>
 #include <dmlc/serializer.h>

-#include <vector>
-#include <utility>
-#include <tuple>
 #include <string>
-#include "./types.h"
+#include <tuple>
+#include <utility>
+#include <vector>
+
 #include "./array_ops.h"
-#include "./spmat.h"
 #include "./macro.h"
+#include "./spmat.h"
+#include "./types.h"

 namespace dgl {
 namespace aten {
@@ -52,9 +54,9 @@ struct COOMatrix {
   /** @brief default constructor */
   COOMatrix() = default;
   /** @brief constructor */
-  COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
-            IdArray darr = NullArray(), bool rsorted = false,
-            bool csorted = false)
+  COOMatrix(
+      int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
+      IdArray darr = NullArray(), bool rsorted = false, bool csorted = false)
       : num_rows(nrows),
         num_cols(ncols),
         row(rarr),
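For orientation, a sketch of using this constructor (the header path dgl/aten/coo.h is an assumption; id arrays come from VecToIdArray in array_ops.h):

    #include <vector>
    #include <dgl/aten/coo.h>

    dgl::aten::COOMatrix MakeCoo() {
      // A 3 x 3 matrix with nonzeros at (0,1), (1,0), (1,2), (2,2).
      auto row = dgl::aten::VecToIdArray(std::vector<int64_t>{0, 1, 1, 2});
      auto col = dgl::aten::VecToIdArray(std::vector<int64_t>{1, 0, 2, 2});
      // data defaults to NullArray(); rsorted/csorted default to false.
      return dgl::aten::COOMatrix(3, 3, row, col);
    }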
@@ -79,8 +81,9 @@ struct COOMatrix {
   // Convert to a SparseMatrix object that can return to python.
   SparseMatrix ToSparseMatrix() const {
-    return SparseMatrix(static_cast<int32_t>(SparseFormat::kCOO), num_rows,
-                        num_cols, {row, col, data}, {row_sorted, col_sorted});
+    return SparseMatrix(
+        static_cast<int32_t>(SparseFormat::kCOO), num_rows, num_cols,
+        {row, col, data}, {row_sorted, col_sorted});
   }

   bool Load(dmlc::Stream* fs) {
@@ -122,25 +125,24 @@ struct COOMatrix {
   }

   /** @brief Return a copy of this matrix on the give device context. */
-  inline COOMatrix CopyTo(const DGLContext &ctx) const {
-    if (ctx == row->ctx)
-      return *this;
-    return COOMatrix(num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
-                     aten::IsNullArray(data) ? data : data.CopyTo(ctx),
-                     row_sorted, col_sorted);
+  inline COOMatrix CopyTo(const DGLContext& ctx) const {
+    if (ctx == row->ctx) return *this;
+    return COOMatrix(
+        num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
+        aten::IsNullArray(data) ? data : data.CopyTo(ctx), row_sorted,
+        col_sorted);
   }

   /**
    * @brief Pin the row, col and data (if not Null) of the matrix.
    * @note This is an in-place method. Behavior depends on the current context,
    *       kDGLCPU: will be pinned;
    *       IsPinned: directly return;
    *       kDGLCUDA: invalid, will throw an error.
    *       The context check is deferred to pinning the NDArray.
    */
   inline void PinMemory_() {
-    if (is_pinned)
-      return;
+    if (is_pinned) return;
     row.PinMemory_();
     col.PinMemory_();
     if (!aten::IsNullArray(data)) {
@@ -150,15 +152,14 @@ struct COOMatrix {
   }

   /**
    * @brief Unpin the row, col and data (if not Null) of the matrix.
    * @note This is an in-place method. Behavior depends on the current context,
    *       IsPinned: will be unpinned;
    *       others: directly return.
    *       The context check is deferred to unpinning the NDArray.
    */
   inline void UnpinMemory_() {
-    if (!is_pinned)
-      return;
+    if (!is_pinned) return;
     row.UnpinMemory_();
     col.UnpinMemory_();
     if (!aten::IsNullArray(data)) {
@@ -183,25 +184,25 @@ struct COOMatrix {
 ///////////////////////// COO routines //////////////////////////

 /** @brief Return true if the value (row, col) is non-zero */
-bool COOIsNonZero(COOMatrix , int64_t row, int64_t col);
+bool COOIsNonZero(COOMatrix, int64_t row, int64_t col);
 /**
  * @brief Batched implementation of COOIsNonZero.
- * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
+ * @note This operator allows broadcasting (i.e, either row or col can be of
+ * length 1).
  */
-runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col);
+runtime::NDArray COOIsNonZero(
+    COOMatrix, runtime::NDArray row, runtime::NDArray col);

 /** @brief Return the nnz of the given row */
-int64_t COOGetRowNNZ(COOMatrix , int64_t row);
-runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row);
+int64_t COOGetRowNNZ(COOMatrix, int64_t row);
+runtime::NDArray COOGetRowNNZ(COOMatrix, runtime::NDArray row);

 /** @brief Return the data array of the given row */
-std::pair<runtime::NDArray, runtime::NDArray>
-COOGetRowDataAndIndices(COOMatrix , int64_t row);
+std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
+    COOMatrix, int64_t row);

 /** @brief Whether the COO matrix contains data */
-inline bool COOHasData(COOMatrix csr) {
-  return !IsNullArray(csr.data);
-}
+inline bool COOHasData(COOMatrix csr) { return !IsNullArray(csr.data); }

 /**
  * @brief Check whether the COO is sorted.
@@ -217,11 +218,12 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo);
 /**
  * @brief Get the data and the row,col indices for each returned entries.
  *
- * The operator supports matrix with duplicate entries and all the matched entries
- * will be returned. The operator assumes there is NO duplicate (row, col) pair
- * in the given input. Otherwise, the returned result is undefined.
+ * The operator supports matrix with duplicate entries and all the matched
+ * entries will be returned. The operator assumes there is NO duplicate (row,
+ * col) pair in the given input. Otherwise, the returned result is undefined.
  *
- * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
+ * @note This operator allows broadcasting (i.e, either row or col can be of
+ * length 1).
  * @param mat Sparse matrix
  * @param rows Row index
  * @param cols Column index
@@ -230,10 +232,13 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo);
 std::vector<runtime::NDArray> COOGetDataAndIndices(
     COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);

-/** @brief Get data. The return type is an ndarray due to possible duplicate entries. */
+/** @brief Get data. The return type is an ndarray due to possible duplicate
+ * entries. */
 inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
-  IdArray rows = VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);
-  IdArray cols = VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);
+  IdArray rows =
+      VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);
+  IdArray cols =
+      VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);
   const auto& rst = COOGetDataAndIndices(mat, rows, cols);
   return rst[2];
 }
@@ -241,18 +246,20 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
 /**
  * @brief Get the data for each (row, col) pair.
  *
- * The operator supports matrix with duplicate entries but only one matched entry
- * will be returned for each (row, col) pair. Support duplicate input (row, col)
- * pairs.
+ * The operator supports matrix with duplicate entries but only one matched
+ * entry will be returned for each (row, col) pair. Support duplicate input
+ * (row, col) pairs.
  *
- * @note This operator allows broadcasting (i.e, either row or col can be of length 1).
+ * @note This operator allows broadcasting (i.e, either row or col can be of
+ * length 1).
  *
  * @param mat Sparse matrix.
  * @param rows Row index.
  * @param cols Column index.
  * @return Data array. The i^th element is the data of (rows[i], cols[i])
  */
-runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
+runtime::NDArray COOGetData(
+    COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);

 /** @brief Return a transposed COO matrix */
 COOMatrix COOTranspose(COOMatrix coo);
@@ -301,14 +308,15 @@ COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
  * @param cols The col index to select
  * @return submatrix
  */
-COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
+COOMatrix COOSliceMatrix(
+    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

 /** @return True if the matrix has duplicate entries */
 bool COOHasDuplicate(COOMatrix coo);

 /**
- * @brief Deduplicate the entries of a sorted COO matrix, replacing the data with the
- * number of occurrences of the row-col coordinates.
+ * @brief Deduplicate the entries of a sorted COO matrix, replacing the data
+ * with the number of occurrences of the row-col coordinates.
  */
 std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
@@ -316,11 +324,13 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
  * @brief Sort the indices of a COO matrix in-place.
  *
  * The function sorts row indices in ascending order. If sort_column is true,
- * col indices are sorted in ascending order too. The data array of the returned COOMatrix
- * stores the shuffled index which could be used to fetch edge data.
+ * col indices are sorted in ascending order too. The data array of the returned
+ * COOMatrix stores the shuffled index which could be used to fetch edge data.
  *
- * Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros.
- * TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space.
+ * Complexity: O(N*log(N)) time and O(1) space, where N is the number of
+ * nonzeros.
+ * TODO(minjie): The time complexity could be improved to O(N) by using a O(N)
+ * space.
  *
  * @param mat The coo matrix to sort.
  * @param sort_column True if column index should be sorted too.
@@ -331,31 +341,32 @@ void COOSort_(COOMatrix* mat, bool sort_column = false);
  * @brief Sort the indices of a COO matrix.
  *
  * The function sorts row indices in ascending order. If sort_column is true,
- * col indices are sorted in ascending order too. The data array of the returned COOMatrix
- * stores the shuffled index which could be used to fetch edge data.
+ * col indices are sorted in ascending order too. The data array of the returned
+ * COOMatrix stores the shuffled index which could be used to fetch edge data.
  *
- * Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros.
- * TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space.
+ * Complexity: O(N*log(N)) time and O(1) space, where N is the number of
+ * nonzeros.
+ * TODO(minjie): The time complexity could be improved to O(N) by using a O(N)
+ * space.
  *
  * @param mat The input coo matrix
  * @param sort_column True if column index should be sorted too.
  * @return COO matrix with index sorted.
  */
 inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) {
-  if ((mat.row_sorted && !sort_column) || mat.col_sorted)
-    return mat;
-  COOMatrix ret(mat.num_rows, mat.num_cols,
-                mat.row.Clone(), mat.col.Clone(),
-                COOHasData(mat)? mat.data.Clone() : mat.data,
-                mat.row_sorted, mat.col_sorted);
+  if ((mat.row_sorted && !sort_column) || mat.col_sorted) return mat;
+  COOMatrix ret(
+      mat.num_rows, mat.num_cols, mat.row.Clone(), mat.col.Clone(),
+      COOHasData(mat) ? mat.data.Clone() : mat.data, mat.row_sorted,
+      mat.col_sorted);
   COOSort_(&ret, sort_column);
  return ret;
 }
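A usage sketch, continuing the coo built in the earlier MakeCoo sketch:

    auto coo = MakeCoo();
    // No-op if already row-sorted and no column sort was requested;
    // otherwise row/col (and data, if present) are cloned before sorting.
    auto sorted = dgl::aten::COOSort(coo, /*sort_column=*/true);
    // sorted.data now holds the permutation back into the original
    // entries, so edge features can be gathered through it.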
 /**
  * @brief Remove entries from COO matrix by entry indices (data indices)
- * @return A new COO matrix as well as a mapping from the new COO entries to the old COO
- * entries.
+ * @return A new COO matrix as well as a mapping from the new COO entries to the
+ * old COO entries.
  */
 COOMatrix COORemove(COOMatrix coo, IdArray entries);
@@ -365,10 +376,12 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries);
  * @param new_row_ids the new row Ids (the index is the old row Id)
  * @param new_col_ids the new column Ids (the index is the old col Id).
  */
-COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
+COOMatrix COOReorder(
+    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);

 /**
- * @brief Randomly select a fixed number of non-zero entries along each given row independently.
+ * @brief Randomly select a fixed number of non-zero entries along each given
+ * row independently.
  *
  * The function performs random choices along each row independently.
  * The picked indices are returned in the form of a COO matrix.
@@ -400,15 +413,12 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArr
  *        Should be of the same length as the data array.
  *        If an empty array is provided, assume uniform.
  * @param replace True if sample with replacement
- * @return A COOMatrix storing the picked row and col indices. Its data field stores the
- *         the index of the picked elements in the value array.
+ * @return A COOMatrix storing the picked row and col indices. Its data field
+ * stores the the index of the picked elements in the value array.
  */
 COOMatrix COORowWiseSampling(
-    COOMatrix mat,
-    IdArray rows,
-    int64_t num_samples,
-    NDArray prob_or_mask = NDArray(),
-    bool replace = true);
+    COOMatrix mat, IdArray rows, int64_t num_samples,
+    NDArray prob_or_mask = NDArray(), bool replace = true);
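A usage sketch under the defaults spelled out above (an empty prob_or_mask means uniform, replace defaults to true), again reusing the earlier MakeCoo sketch:

    auto coo = MakeCoo();
    auto rows = dgl::aten::VecToIdArray(std::vector<int64_t>{0, 2});
    // Pick up to 2 entries from row 0 and row 2, independently per row.
    auto picked = dgl::aten::COORowWiseSampling(coo, rows, 2);
    // picked.data indexes the chosen entries in the original value array.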
 /**
  * @brief Randomly select a fixed number of non-zero entries for each edge type
@@ -433,8 +443,8 @@ COOMatrix COORowWiseSampling(
  *   COOMatrix coo = ...;
  *   IdArray rows = ... ; // [0, 3]
  *   std::vector<int64_t> num_samples = {2, 2, 2};
- *   COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset, num_samples,
- *                                                  FloatArray(), false);
+ *   COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset,
+ *       num_samples, FloatArray(), false);
  *   // possible sampled coo matrix:
  *   //  sampled.num_rows = 4
  *   //  sampled.num_cols = 4
@@ -450,20 +460,18 @@ COOMatrix COORowWiseSampling(
  *        Should be of the same length as the data array.
  *        If an empty array is provided, assume uniform.
  * @param replace True if sample with replacement
- * @return A COOMatrix storing the picked row and col indices. Its data field stores the
- *         the index of the picked elements in the value array.
+ * @return A COOMatrix storing the picked row and col indices. Its data field
+ * stores the the index of the picked elements in the value array.
  * @note The edges of the entire graph must be ordered by their edge types.
  */
 COOMatrix COORowWisePerEtypeSampling(
-    COOMatrix mat,
-    IdArray rows,
-    const std::vector<int64_t>& eid2etype_offset,
+    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
     const std::vector<int64_t>& num_samples,
-    const std::vector<NDArray>& prob_or_mask,
-    bool replace = true);
+    const std::vector<NDArray>& prob_or_mask, bool replace = true);

 /**
- * @brief Select K non-zero entries with the largest weights along each given row.
+ * @brief Select K non-zero entries with the largest weights along each given
+ * row.
  *
  * The function performs top-k selection along each row independently.
  * The picked indices are returned in the form of a COO matrix.
@@ -492,18 +500,15 @@ COOMatrix COORowWisePerEtypeSampling(
  * @param mat Input COO matrix.
  * @param rows Rows to sample from.
  * @param k The K value.
- * @param weight Weight associated with each entry. Should be of the same length as the
- *        data array. If an empty array is provided, assume uniform.
- * @param ascending If true, elements are sorted by ascending order, equivalent to find
- *        the K smallest values. Otherwise, find K largest values.
- * @return A COOMatrix storing the picked row and col indices. Its data field stores the
- *         the index of the picked elements in the value array.
+ * @param weight Weight associated with each entry. Should be of the same length
+ * as the data array. If an empty array is provided, assume uniform.
+ * @param ascending If true, elements are sorted by ascending order, equivalent
+ * to find the K smallest values. Otherwise, find K largest values.
+ * @return A COOMatrix storing the picked row and col indices. Its data field
+ * stores the the index of the picked elements in the value array.
  */
 COOMatrix COORowWiseTopk(
-    COOMatrix mat,
-    IdArray rows,
-    int64_t k,
-    NDArray weight,
+    COOMatrix mat, IdArray rows, int64_t k, NDArray weight,
     bool ascending = false);

 /**
@@ -535,8 +540,7 @@ COOMatrix COORowWiseTopk(
  *   COOMatrix_C.num_rows : 3
  *   COOMatrix_C.num_cols : 4
  */
-COOMatrix UnionCoo(
-    const std::vector<COOMatrix>& coos);
+COOMatrix UnionCoo(const std::vector<COOMatrix>& coos);

 /**
  * @brief DisjointUnion a list COOMatrix into one COOMatrix.
@@ -566,12 +570,13 @@ COOMatrix UnionCoo(
  *   COOMatrix_C.num_cols : 5
  *
  * @param coos The input list of coo matrix.
- * @param src_offset A list of integers recording src vertix id offset of each Matrix in coos
- * @param src_offset A list of integers recording dst vertix id offset of each Matrix in coos
+ * @param src_offset A list of integers recording src vertix id offset of each
+ * Matrix in coos
+ * @param src_offset A list of integers recording dst vertix id offset of each
+ * Matrix in coos
  * @return The combined COOMatrix.
  */
-COOMatrix DisjointUnionCoo(
-    const std::vector<COOMatrix>& coos);
+COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);

 /**
  * @brief COOMatrix toSimple.
@@ -591,8 +596,8 @@ COOMatrix DisjointUnionCoo(
  *      edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4]
  *
  * @return The simplified COOMatrix
- *         The count recording the number of duplicated edges from the original graph.
- *         The edge mapping from the edge IDs of original graph to those of the
+ *         The count recording the number of duplicated edges from the original
+ * graph. The edge mapping from the edge IDs of original graph to those of the
  *         returned graph.
  */
 std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
@@ -642,11 +647,10 @@ std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
  * @return A list of COOMatrixes representing each disjoint components.
  */
 std::vector<COOMatrix> DisjointPartitionCooBySizes(
-    const COOMatrix &coo,
-    const uint64_t batch_size,
-    const std::vector<uint64_t> &edge_cumsum,
-    const std::vector<uint64_t> &src_vertex_cumsum,
-    const std::vector<uint64_t> &dst_vertex_cumsum);
+    const COOMatrix& coo, const uint64_t batch_size,
+    const std::vector<uint64_t>& edge_cumsum,
+    const std::vector<uint64_t>& src_vertex_cumsum,
+    const std::vector<uint64_t>& dst_vertex_cumsum);

 /**
  * @brief Slice a contiguous chunk from a COOMatrix
@@ -684,10 +688,9 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
  * @return COOMatrix representing the chunk.
  */
 COOMatrix COOSliceContiguousChunk(
-    const COOMatrix &coo,
-    const std::vector<uint64_t> &edge_range,
-    const std::vector<uint64_t> &src_vertex_range,
-    const std::vector<uint64_t> &dst_vertex_range);
+    const COOMatrix& coo, const std::vector<uint64_t>& edge_range,
+    const std::vector<uint64_t>& src_vertex_range,
+    const std::vector<uint64_t>& dst_vertex_range);

 /**
  * @brief Create a LineGraph of input coo
@@ -716,10 +719,11 @@ COOMatrix COOSliceContiguousChunk(
  *       [0, 1, 1, 0, 0]]
  *
  * @param coo COOMatrix to create the LineGraph
- * @param backtracking whether the pair of (v, u) (u, v) edges are treated as linked
+ * @param backtracking whether the pair of (v, u) (u, v) edges are treated as
+ * linked
  * @return LineGraph in COO format
  */
-COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
+COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);

 }  // namespace aten
 }  // namespace dgl
......
@@ -223,8 +223,10 @@ bool CSRIsSorted(CSRMatrix csr);
 std::vector<runtime::NDArray> CSRGetDataAndIndices(
     CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);

-/* @brief Get data. The return type is an ndarray due to possible duplicate
- * entries. */
+/**
+ * @brief Get data. The return type is an ndarray due to possible duplicate
+ * entries.
+ */
 inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
   const auto& nbits = mat.indptr->dtype.bits;
   const auto& ctx = mat.indptr->ctx;
......
@@ -17,16 +17,16 @@
  *   DeviceSpecificImplementation<XPU>(...);
  * });
  */
-#define ATEN_XPU_SWITCH(val, XPU, op, ...) do {                   \
-  if ((val) == kDGLCPU) {                                         \
-    constexpr auto XPU = kDGLCPU;                                 \
-    {__VA_ARGS__}                                                 \
-  } else {                                                        \
-    LOG(FATAL) << "Operator " << (op) << " does not support "     \
-               << dgl::runtime::DeviceTypeCode2Str(val)           \
-               << " device.";                                     \
-  }                                                               \
-} while (0)
+#define ATEN_XPU_SWITCH(val, XPU, op, ...)                                 \
+  do {                                                                     \
+    if ((val) == kDGLCPU) {                                                \
+      constexpr auto XPU = kDGLCPU;                                        \
+      { __VA_ARGS__ }                                                      \
+    } else {                                                               \
+      LOG(FATAL) << "Operator " << (op) << " does not support "            \
+                 << dgl::runtime::DeviceTypeCode2Str(val) << " device.";   \
+    }                                                                      \
+  } while (0)
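A fuller version of the usage pattern from the comment block; DoSomething and SomeImpl are hypothetical names for illustration, and the kernel template is assumed to be parameterized on the device enum:

    template <DGLDeviceType XPU>
    void SomeImpl(dgl::runtime::NDArray array);  // hypothetical kernel

    void DoSomething(dgl::runtime::NDArray array) {
      ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "DoSomething", {
        // Inside the braces XPU is a compile-time device constant
        // (kDGLCPU here; any other device hits the LOG(FATAL) branch).
        SomeImpl<XPU>(array);
      });
    }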
 /**
  * Dispatch according to device:
@@ -37,24 +37,24 @@
  *   // Now XPU is a placeholder for array->ctx.device_type
  *   DeviceSpecificImplementation<XPU>(...);
  * });
  *
  * We treat pinned memory as normal host memory if we don't want
  * to enable CUDA UVA access for this operator
  */
 #ifdef DGL_USE_CUDA
-#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do {              \
-  if ((val) == kDGLCPU) {                                         \
-    constexpr auto XPU = kDGLCPU;                                 \
-    {__VA_ARGS__}                                                 \
-  } else if ((val) == kDGLCUDA) {                                 \
-    constexpr auto XPU = kDGLCUDA;                                \
-    {__VA_ARGS__}                                                 \
-  } else {                                                        \
-    LOG(FATAL) << "Operator " << (op) << " does not support "     \
-               << dgl::runtime::DeviceTypeCode2Str(val)           \
-               << " device.";                                     \
-  }                                                               \
-} while (0)
+#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...)                            \
+  do {                                                                     \
+    if ((val) == kDGLCPU) {                                                \
+      constexpr auto XPU = kDGLCPU;                                        \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val) == kDGLCUDA) {                                        \
+      constexpr auto XPU = kDGLCUDA;                                       \
+      { __VA_ARGS__ }                                                      \
+    } else {                                                               \
+      LOG(FATAL) << "Operator " << (op) << " does not support "            \
+                 << dgl::runtime::DeviceTypeCode2Str(val) << " device.";   \
+    }                                                                      \
+  } while (0)
 #else  // DGL_USE_CUDA
 #define ATEN_XPU_SWITCH_CUDA ATEN_XPU_SWITCH
 #endif  // DGL_USE_CUDA
@@ -68,18 +68,19 @@
  *   DType *data = static_cast<DType *>(array->data);
  * });
  */
-#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do {                \
-  CHECK_EQ((val).code, kDGLInt) << "ID must be integer type";     \
-  if ((val).bits == 32) {                                         \
-    typedef int32_t IdType;                                       \
-    {__VA_ARGS__}                                                 \
-  } else if ((val).bits == 64) {                                  \
-    typedef int64_t IdType;                                       \
-    {__VA_ARGS__}                                                 \
-  } else {                                                        \
-    LOG(FATAL) << "ID can only be int32 or int64";                \
-  }                                                               \
-} while (0)
+#define ATEN_ID_TYPE_SWITCH(val, IdType, ...)                              \
+  do {                                                                     \
+    CHECK_EQ((val).code, kDGLInt) << "ID must be integer type";            \
+    if ((val).bits == 32) {                                                \
+      typedef int32_t IdType;                                              \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).bits == 64) {                                         \
+      typedef int64_t IdType;                                              \
+      { __VA_ARGS__ }                                                      \
+    } else {                                                               \
+      LOG(FATAL) << "ID can only be int32 or int64";                       \
+    }                                                                      \
+  } while (0)
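A concrete instantiation of the pattern sketched in the comment; SumIds is a hypothetical helper:

    int64_t SumIds(dgl::runtime::NDArray ids) {
      int64_t total = 0;
      ATEN_ID_TYPE_SWITCH(ids->dtype, IdType, {
        // IdType is int32_t or int64_t depending on ids->dtype.bits.
        const IdType* data = static_cast<IdType*>(ids->data);
        for (int64_t i = 0; i < ids->shape[0]; ++i) total += data[i];
      });
      return total;
    }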
 /**
  * Dispatch according to bits (either int32 or int64):
@@ -90,18 +91,18 @@
  *   DType *data = static_cast<DType *>(array->data);
  * });
  */
 #define ATEN_ID_BITS_SWITCH(bits, IdType, ...)                      \
   do {                                                              \
     CHECK((bits) == 32 || (bits) == 64) << "bits must be 32 or 64"; \
     if ((bits) == 32) {                                             \
       typedef int32_t IdType;                                       \
       { __VA_ARGS__ }                                               \
     } else if ((bits) == 64) {                                      \
       typedef int64_t IdType;                                       \
       { __VA_ARGS__ }                                               \
     } else {                                                        \
       LOG(FATAL) << "ID can only be int32 or int64";                \
     }                                                               \
   } while (0)
 /**
@@ -113,75 +114,79 @@
  *   FloatType *data = static_cast<FloatType *>(array->data);
  * });
  */
-#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do {  \
-  CHECK_EQ((val).code, kDGLFloat)                                   \
-    << (val_name) << " must be float type";                         \
-  if ((val).bits == 32) {                                           \
-    typedef float FloatType;                                        \
-    {__VA_ARGS__}                                                   \
-  } else if ((val).bits == 64) {                                    \
-    typedef double FloatType;                                       \
-    {__VA_ARGS__}                                                   \
-  } else {                                                          \
-    LOG(FATAL) << (val_name)                                        \
-               << " can only be float32 or float64";                \
-  }                                                                 \
-} while (0)
+#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...)               \
+  do {                                                                      \
+    CHECK_EQ((val).code, kDGLFloat) << (val_name) << " must be float type"; \
+    if ((val).bits == 32) {                                                 \
+      typedef float FloatType;                                              \
+      { __VA_ARGS__ }                                                       \
+    } else if ((val).bits == 64) {                                          \
+      typedef double FloatType;                                             \
+      { __VA_ARGS__ }                                                       \
+    } else {                                                                \
+      LOG(FATAL) << (val_name) << " can only be float32 or float64";        \
+    }                                                                       \
+  } while (0)
 /**
- * Dispatch according to float type, including 16bits (float16/bfloat16/float32/float64).
+ * Dispatch according to float type, including 16bits
+ * (float16/bfloat16/float32/float64).
  */
 #ifdef DGL_USE_CUDA
 #if BF16_ENABLED
-#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) do { \
-  CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))                   \
-    << (val_name) << " must be float type";                                    \
-  if ((val).bits == 32) {                                                      \
-    typedef float FloatType;                                                   \
-    {__VA_ARGS__}                                                              \
-  } else if ((val).bits == 64) {                                               \
-    typedef double FloatType;                                                  \
-    {__VA_ARGS__}                                                              \
-  } else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
-    typedef __half FloatType;                                                  \
-    {__VA_ARGS__}                                                              \
-  } else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
-    typedef __nv_bfloat16 FloatType;                                           \
-    {__VA_ARGS__}                                                              \
-  } else if (XPU == kDGLCPU) {                                                 \
-    LOG(FATAL) << (val_name)                                                   \
-               << " can only be float32 or float64 on CPU";                    \
-  } else {                                                                     \
-    LOG(FATAL) << (val_name)                                                   \
-               << " can only be float16/bfloat16/float32/float64 on GPU";      \
-  }                                                                            \
-} while (0)
+#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...)   \
+  do {                                                                      \
+    CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))              \
+        << (val_name) << " must be float type";                             \
+    if ((val).bits == 32) {                                                 \
+      typedef float FloatType;                                              \
+      { __VA_ARGS__ }                                                       \
+    } else if ((val).bits == 64) {                                          \
+      typedef double FloatType;                                             \
+      { __VA_ARGS__ }                                                       \
+    } else if (                                                             \
+        XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) {   \
+      typedef __half FloatType;                                             \
+      { __VA_ARGS__ }                                                       \
+    } else if (                                                             \
+        XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) {  \
+      typedef __nv_bfloat16 FloatType;                                      \
+      { __VA_ARGS__ }                                                       \
+    } else if (XPU == kDGLCPU) {                                            \
+      LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
+    } else {                                                                \
+      LOG(FATAL) << (val_name)                                              \
+                 << " can only be float16/bfloat16/float32/float64 on GPU"; \
+    }                                                                       \
+  } while (0)
 #else  // BF16_ENABLED
-#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) do { \
-  CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))                   \
-    << (val_name) << " must be float type";                                    \
-  if ((val).bits == 32) {                                                      \
-    typedef float FloatType;                                                   \
-    {__VA_ARGS__}                                                              \
-  } else if ((val).bits == 64) {                                               \
-    typedef double FloatType;                                                  \
-    {__VA_ARGS__}                                                              \
-  } else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
-    typedef __half FloatType;                                                  \
-    {__VA_ARGS__}                                                              \
-  } else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
-    LOG(FATAL) << "bfloat16 requires CUDA >= 11.0";                            \
-  } else if (XPU == kDGLCPU) {                                                 \
-    LOG(FATAL) << (val_name)                                                   \
-               << " can only be float32 or float64 on CPU";                    \
-  } else {                                                                     \
-    LOG(FATAL) << (val_name)                                                   \
-               << " can only be float16/float32/float64 on GPU";               \
-  }                                                                            \
-} while (0)
+#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...)   \
+  do {                                                                      \
+    CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat))              \
+        << (val_name) << " must be float type";                             \
+    if ((val).bits == 32) {                                                 \
+      typedef float FloatType;                                              \
+      { __VA_ARGS__ }                                                       \
+    } else if ((val).bits == 64) {                                          \
+      typedef double FloatType;                                             \
+      { __VA_ARGS__ }                                                       \
+    } else if (                                                             \
+        XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) {   \
+      typedef __half FloatType;                                             \
+      { __VA_ARGS__ }                                                       \
+    } else if (                                                             \
+        XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) {  \
+      LOG(FATAL) << "bfloat16 requires CUDA >= 11.0";                       \
+    } else if (XPU == kDGLCPU) {                                            \
+      LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
+    } else {                                                                \
+      LOG(FATAL) << (val_name)                                              \
+                 << " can only be float16/float32/float64 on GPU";          \
+    }                                                                       \
+  } while (0)
 #endif  // BF16_ENABLED
 #else  // DGL_USE_CUDA
 #define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
   ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__})
 #endif  // DGL_USE_CUDA
@@ -194,23 +199,25 @@
  *   DType *data = static_cast<DType *>(array->data);
  * });
  */
-#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do {                        \
-  if ((val).code == kDGLInt && (val).bits == 32) {                               \
-    typedef int32_t DType;                                                       \
-    {__VA_ARGS__}                                                                \
-  } else if ((val).code == kDGLInt && (val).bits == 64) {                        \
-    typedef int64_t DType;                                                       \
-    {__VA_ARGS__}                                                                \
-  } else if ((val).code == kDGLFloat && (val).bits == 32) {                      \
-    typedef float DType;                                                         \
-    {__VA_ARGS__}                                                                \
-  } else if ((val).code == kDGLFloat && (val).bits == 64) {                      \
-    typedef double DType;                                                        \
-    {__VA_ARGS__}                                                                \
-  } else {                                                                       \
-    LOG(FATAL) << (val_name) << " can only be int32, int64, float32 or float64"; \
-  }                                                                              \
-} while (0)
+#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...)                       \
+  do {                                                                     \
+    if ((val).code == kDGLInt && (val).bits == 32) {                       \
+      typedef int32_t DType;                                               \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).code == kDGLInt && (val).bits == 64) {                \
+      typedef int64_t DType;                                               \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).code == kDGLFloat && (val).bits == 32) {              \
+      typedef float DType;                                                 \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).code == kDGLFloat && (val).bits == 64) {              \
+      typedef double DType;                                                \
+      { __VA_ARGS__ }                                                      \
+    } else {                                                               \
+      LOG(FATAL) << (val_name)                                             \
+                 << " can only be int32, int64, float32 or float64";       \
+    }                                                                      \
+  } while (0)
 /**
  * Dispatch according to data type (int8, uint8, float32 or float64):
@@ -221,23 +228,25 @@
  *   DType *data = static_cast<DType *>(array->data);
  * });
  */
-#define ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(val, DType, val_name, ...) do {        \
-  if ((val).code == kDGLInt && (val).bits == 8) {                                \
-    typedef int8_t DType;                                                        \
-    {__VA_ARGS__}                                                                \
-  } else if ((val).code == kDGLUInt && (val).bits == 8) {                        \
-    typedef uint8_t DType;                                                       \
-    {__VA_ARGS__}                                                                \
-  } else if ((val).code == kDGLFloat && (val).bits == 32) {                      \
-    typedef float DType;                                                         \
-    {__VA_ARGS__}                                                                \
-  } else if ((val).code == kDGLFloat && (val).bits == 64) {                      \
-    typedef double DType;                                                        \
-    {__VA_ARGS__}                                                                \
-  } else {                                                                       \
-    LOG(FATAL) << (val_name) << " can only be int8, uint8, float32 or float64";  \
-  }                                                                              \
-} while (0)
+#define ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(val, DType, val_name, ...)       \
+  do {                                                                     \
+    if ((val).code == kDGLInt && (val).bits == 8) {                        \
+      typedef int8_t DType;                                                \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).code == kDGLUInt && (val).bits == 8) {                \
+      typedef uint8_t DType;                                               \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).code == kDGLFloat && (val).bits == 32) {              \
+      typedef float DType;                                                 \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).code == kDGLFloat && (val).bits == 64) {              \
+      typedef double DType;                                                \
+      { __VA_ARGS__ }                                                      \
+    } else {                                                               \
+      LOG(FATAL) << (val_name)                                             \
+                 << " can only be int8, uint8, float32 or float64";        \
+    }                                                                      \
+  } while (0)
 /**
  * Dispatch data type only based on bit-width (8-bit, 16-bit, 32-bit, 64-bit):
@@ -250,61 +259,62 @@
  *   DType *data = static_cast<DType *>(array->data);
  * });
  */
-#define ATEN_DTYPE_BITS_ONLY_SWITCH(val, DType, val_name, ...) do {              \
-  if ((val).bits == 8) {                                                         \
-    typedef int8_t DType;                                                        \
-    {__VA_ARGS__}                                                                \
-  } else if ((val).bits == 16) {                                                 \
-    typedef int16_t DType;                                                       \
-    {__VA_ARGS__}                                                                \
-  } else if ((val).bits == 32) {                                                 \
-    typedef int32_t DType;                                                       \
-    {__VA_ARGS__}                                                                \
-  } else if ((val).bits == 64) {                                                 \
-    typedef int64_t DType;                                                       \
-    {__VA_ARGS__}                                                                \
-  } else {                                                                       \
-    LOG(FATAL) << (val_name) << " can only be 8-bit, 16-bit, 32-bit, or 64-bit"; \
-  }                                                                              \
-} while (0)
+#define ATEN_DTYPE_BITS_ONLY_SWITCH(val, DType, val_name, ...)             \
+  do {                                                                     \
+    if ((val).bits == 8) {                                                 \
+      typedef int8_t DType;                                                \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).bits == 16) {                                         \
+      typedef int16_t DType;                                               \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).bits == 32) {                                         \
+      typedef int32_t DType;                                               \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).bits == 64) {                                         \
+      typedef int64_t DType;                                               \
+      { __VA_ARGS__ }                                                      \
+    } else {                                                               \
+      LOG(FATAL) << (val_name)                                             \
+                 << " can only be 8-bit, 16-bit, 32-bit, or 64-bit";       \
+    }                                                                      \
+  } while (0)
 /**
  * Dispatch according to integral type of CSR graphs.
  * Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
  */
-#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do {                \
-  if ((val).code == kDGLInt && (val).bits == 32) {                 \
-    typedef int32_t DType;                                         \
-    {__VA_ARGS__}                                                  \
-  } else if ((val).code == kDGLInt && (val).bits == 64) {          \
-    typedef int64_t DType;                                         \
-    {__VA_ARGS__}                                                  \
-  } else {                                                         \
-    LOG(FATAL) << "CSR matrix data can only be int32 or int64";    \
-  }                                                                \
-} while (0)
+#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...)                             \
+  do {                                                                     \
+    if ((val).code == kDGLInt && (val).bits == 32) {                       \
+      typedef int32_t DType;                                               \
+      { __VA_ARGS__ }                                                      \
+    } else if ((val).code == kDGLInt && (val).bits == 64) {                \
+      typedef int64_t DType;                                               \
+      { __VA_ARGS__ }                                                      \
+    } else {                                                               \
+      LOG(FATAL) << "CSR matrix data can only be int32 or int64";          \
+    }                                                                      \
+  } while (0)
 // Macro to dispatch according to device context and index type.
 #define ATEN_CSR_SWITCH(csr, XPU, IdType, op, ...)                     \
   ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, op, {            \
-    ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {                 \
-      {__VA_ARGS__}                                                    \
-    });                                                                \
+    ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {{__VA_ARGS__}}); \
   });

 // Macro to dispatch according to device context and index type.
 #define ATEN_COO_SWITCH(coo, XPU, IdType, op, ...)                     \
   ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, op, {               \
-    ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {                    \
-      {__VA_ARGS__}                                                    \
-    });                                                                \
+    ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {{__VA_ARGS__}});    \
   });

 #define CHECK_VALID_CONTEXT(VAR1, VAR2)                                \
   CHECK(((VAR1)->ctx == (VAR2)->ctx) || (VAR1).IsPinned())             \
-      << "Expected " << (#VAR2) << "(" << (VAR2)->ctx << ")" << " to have the same device " \
-      << "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). "       \
-      << "Or " << (#VAR1) << "(" << (VAR1)->ctx << ")" << " is pinned";
+      << "Expected " << (#VAR2) << "(" << (VAR2)->ctx << ")"           \
+      << " to have the same device "                                   \
+      << "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). "       \
+      << "Or " << (#VAR1) << "(" << (VAR1)->ctx << ")"                 \
+      << " is pinned";
 /**
  * Macro to dispatch according to the context of array and dtype of csr
@@ -313,30 +323,25 @@
  * If csr has the same context with array, same behivor as ATEN_CSR_SWITCH_CUDA.
  * If csr is pinned, array's context will conduct the actual operation.
  */
-#define ATEN_CSR_SWITCH_CUDA_UVA(csr, array, XPU, IdType, op, ...) do { \
-  CHECK_VALID_CONTEXT(csr.indices, array);                              \
-  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, op, {               \
-    ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {                  \
-      {__VA_ARGS__}                                                     \
-    });                                                                 \
-  });                                                                   \
-} while (0)
+#define ATEN_CSR_SWITCH_CUDA_UVA(csr, array, XPU, IdType, op, ...)        \
+  do {                                                                    \
+    CHECK_VALID_CONTEXT(csr.indices, array);                              \
+    ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, op, {               \
+      ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {{__VA_ARGS__}});  \
+    });                                                                   \
+  } while (0)
 // Macro to dispatch according to device context (allowing cuda)
 #ifdef DGL_USE_CUDA
 #define ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, op, ...)                \
   ATEN_XPU_SWITCH_CUDA((csr).indptr->ctx.device_type, XPU, op, {       \
-    ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {                 \
-      {__VA_ARGS__}                                                    \
-    });                                                                \
+    ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {{__VA_ARGS__}}); \
   });

 // Macro to dispatch according to device context and index type.
 #define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...)                \
   ATEN_XPU_SWITCH_CUDA((coo).row->ctx.device_type, XPU, op, {          \
-    ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {                    \
-      {__VA_ARGS__}                                                    \
-    });                                                                \
+    ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {{__VA_ARGS__}});    \
   });
 #else  // DGL_USE_CUDA
 #define ATEN_CSR_SWITCH_CUDA ATEN_CSR_SWITCH
...@@ -345,54 +350,56 @@ ...@@ -345,54 +350,56 @@
///////////////////////// Array checks //////////////////////////
#define IS_INT32(a) ((a)->dtype.code == kDGLInt && (a)->dtype.bits == 32)
#define IS_INT64(a) ((a)->dtype.code == kDGLInt && (a)->dtype.bits == 64)
#define IS_FLOAT32(a) ((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 32)
#define IS_FLOAT64(a) ((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 64)
#define CHECK_IF(cond, prop, value_name, dtype_name)                           \
  CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " \
              << (dtype_name)
#define CHECK_INT32(value, value_name) \
  CHECK_IF(IS_INT32(value), "dtype", value_name, "int32")
#define CHECK_INT64(value, value_name) \
  CHECK_IF(IS_INT64(value), "dtype", value_name, "int64")
#define CHECK_INT(value, value_name)                             \
  CHECK_IF(                                                      \
      IS_INT32(value) || IS_INT64(value), "dtype", value_name,   \
      "int32 or int64")
#define CHECK_FLOAT32(value, value_name) \
  CHECK_IF(IS_FLOAT32(value), "dtype", value_name, "float32")
#define CHECK_FLOAT64(value, value_name) \
  CHECK_IF(IS_FLOAT64(value), "dtype", value_name, "float64")
#define CHECK_FLOAT(value, value_name)                               \
  CHECK_IF(                                                          \
      IS_FLOAT32(value) || IS_FLOAT64(value), "dtype", value_name,   \
      "float32 or float64")
#define CHECK_NDIM(value, _ndim, value_name) \
  CHECK_IF((value)->ndim == (_ndim), "ndim", value_name, _ndim)
#define CHECK_SAME_DTYPE(VAR1, VAR2)                                       \
  CHECK((VAR1)->dtype == (VAR2)->dtype)                                    \
      << "Expected " << (#VAR2) << " to be the same type as " << (#VAR1)   \
      << "(" << (VAR1)->dtype << ")"                                       \
      << ". But got " << (VAR2)->dtype << ".";
#define CHECK_SAME_CONTEXT(VAR1, VAR2)                                     \
  CHECK((VAR1)->ctx == (VAR2)->ctx)                                        \
      << "Expected " << (#VAR2) << " to have the same device context as " \
      << (#VAR1) << "(" << (VAR1)->ctx << ")"                              \
      << ". But got " << (VAR2)->ctx << ".";
#define CHECK_NO_OVERFLOW(dtype, val)                          \
  do {                                                         \
    if (sizeof(val) == 8 && (dtype).bits == 32)                \
      CHECK_LE((val), 0x7FFFFFFFL)                             \
          << "int32 overflow for argument " << (#val) << "."; \
  } while (0);
#define CHECK_IS_ID_ARRAY(VAR)                                  \
  CHECK((VAR)->ndim == 1 && (IS_INT32(VAR) || IS_INT64(VAR)))   \
      << "Expected argument " << (#VAR) << " to be a 1D integer array.";
#endif  // DGL_ATEN_MACRO_H_
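For orientation, a minimal sketch (hypothetical function and argument names) of how the check macros above combine at the entry of an aten kernel:

// Hypothetical validation helper built only from the macros above.
void ValidateEdgeInputs(
    dgl::runtime::NDArray ids, dgl::runtime::NDArray weights) {
  CHECK_IS_ID_ARRAY(ids);                  // 1D int32/int64 array expected
  CHECK_INT(ids, "edge ids");              // dtype must be int32 or int64
  CHECK_FLOAT(weights, "edge weights");    // dtype must be float32 or float64
  CHECK_NDIM(weights, 1, "edge weights");  // ndim must be 1
  CHECK_SAME_CONTEXT(ids, weights);        // both arrays on the same device
}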
...@@ -10,7 +10,8 @@
#include <memory>

namespace dgl {

// Util class to call the private/public empty constructor, which is needed
// for serialization
class Serializer {
 public:
  template <typename T>
...
...@@ -6,12 +6,11 @@
#ifndef DGL_RUNTIME_NDARRAY_H_
#define DGL_RUNTIME_NDARRAY_H_

#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "c_runtime_api.h"
#include "serializer.h"
...@@ -29,7 +28,7 @@
#endif  // DGL_USE_CUDA

// forward declaration
inline std::ostream& operator<<(std::ostream& os, DGLDataType t);

namespace dgl {
...@@ -39,13 +38,13 @@ namespace dgl {
 * Usage:
 *   DGLDataTypeTraits<int>::dtype == dtype
 */
template <typename T>
struct DGLDataTypeTraits {
  static constexpr DGLDataType dtype{0, 0, 0};  // dummy
};
#define GEN_DGLDATATYPETRAITS_FOR(T, code, bits)       \
  template <>                                          \
  struct DGLDataTypeTraits<T> {                        \
    static constexpr DGLDataType dtype{code, bits, 1}; \
  }
GEN_DGLDATATYPETRAITS_FOR(int8_t, kDGLInt, 8);
...@@ -53,8 +52,8 @@ GEN_DGLDATATYPETRAITS_FOR(uint8_t, kDGLUInt, 8);
GEN_DGLDATATYPETRAITS_FOR(int16_t, kDGLInt, 16);
GEN_DGLDATATYPETRAITS_FOR(int32_t, kDGLInt, 32);
GEN_DGLDATATYPETRAITS_FOR(int64_t, kDGLInt, 64);
// XXX(BarclayII) most DL frameworks do not support unsigned int and long
// arrays, so I'm just converting uints to signed DTypes.
GEN_DGLDATATYPETRAITS_FOR(uint32_t, kDGLInt, 32);
GEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64);
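A small sketch of what the traits table buys, assuming the traits live in namespace dgl as above; the helper name is hypothetical:

// Hypothetical helper: compile-time element-type to DGLDataType lookup.
template <typename T>
bool HasElementType(const dgl::runtime::NDArray& array) {
  // DGLDataTypeTraits<int64_t>::dtype is {kDGLInt, 64, 1}, for example.
  return array->dtype == dgl::DGLDataTypeTraits<T>::dtype;
}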
#ifdef DGL_USE_CUDA
...@@ -98,14 +97,12 @@ class NDArray {
   * @brief move constructor
   * @param other The value to be moved
   */
  NDArray(NDArray&& other)  // NOLINT(*)
      : data_(other.data_) {
    other.data_ = nullptr;
  }
  /** @brief destructor */
  ~NDArray() { this->reset(); }
  /**
   * @brief Swap this array with another NDArray
   * @param other The other NDArray
...@@ -130,17 +127,13 @@ class NDArray {
   */
  NDArray& operator=(NDArray&& other) {  // NOLINT(*)
    // copy-and-swap idiom
    NDArray(std::move(other)).swap(*this);  // NOLINT(*)
    return *this;
  }
  /** @return If NDArray is defined */
  bool defined() const { return data_ != nullptr; }
  /** @return If both NDArray reference the same container */
  bool same_as(const NDArray& other) const { return data_ == other.data_; }
  /** @brief reset the content of NDArray to be nullptr */
  inline void reset();
  /**
...@@ -160,22 +153,23 @@ class NDArray {
    else
      return static_cast<T*>(operator->()->data);
  }
  /**
   * @brief Copy data content from/into another array.
   * @param other The source array to be copied from.
   * @note The copy runs on the dgl internal stream if it involves a GPU
   *       context.
   */
  inline void CopyFrom(DGLArray* other);
  inline void CopyFrom(const NDArray& other);
  inline void CopyTo(DGLArray* other) const;
  inline void CopyTo(const NDArray& other) const;
  /**
   * @brief Copy the data to another context.
   * @param ctx The target context.
   * @return The array under another context.
   */
  inline NDArray CopyTo(const DGLContext& ctx) const;
  /**
   * @brief Return a new array with a copy of the content.
   */
...@@ -224,8 +218,8 @@ class NDArray {
   * @param offset The offset (in bytes) of the starting pointer.
   * @note The memory size of new array must be smaller than the current one.
   */
  DGL_DLL NDArray
  CreateView(std::vector<int64_t> shape, DGLDataType dtype, int64_t offset = 0);
  /**
   * @brief Create an empty NDArray.
   * @param shape The shape of the new array.
...@@ -233,9 +227,8 @@ class NDArray {
   * @param ctx The context of the Array.
   * @return The created Array
   */
  DGL_DLL static NDArray Empty(
      std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx);
  /**
   * @brief Create an empty NDArray with shared memory.
   * @param name The name of shared memory.
...@@ -245,11 +238,9 @@ class NDArray {
   * @param is_create whether to create shared memory.
   * @return The created Array
   */
  DGL_DLL static NDArray EmptyShared(
      const std::string& name, std::vector<int64_t> shape, DGLDataType dtype,
      DGLContext ctx, bool is_create);
  /**
   * @brief Get the size of the array in the number of bytes.
   */
...@@ -264,23 +255,24 @@ class NDArray {
   * @brief Create a NDArray by copying from std::vector.
   * @tparam T Type of vector data. Determines the dtype of returned array.
   */
  template <typename T>
  DGL_DLL static NDArray FromVector(
      const std::vector<T>& vec, DGLContext ctx = DGLContext{kDGLCPU, 0});
  /**
   * @brief Create a NDArray from a raw pointer.
   */
  DGL_DLL static NDArray CreateFromRaw(
      const std::vector<int64_t>& shape, DGLDataType dtype, DGLContext ctx,
      void* raw, bool auto_free);
  /**
   * @brief Create a std::vector from a 1D NDArray.
   * @tparam T Type of vector data.
   * @note Type casting is NOT performed. The caller has to make sure that the
   *       vector type matches the dtype of NDArray.
   */
  template <typename T>
  std::vector<T> ToVector() const;

  std::shared_ptr<SharedMemory> GetSharedMem() const;
...@@ -291,8 +283,7 @@ class NDArray {
   * @param to The target array.
   * @param (optional) stream The stream used in copy.
   */
  DGL_DLL static void CopyFromTo(DGLArray* from, DGLArray* to);
  DGL_DLL static void CopyFromTo(
      DGLArray* from, DGLArray* to, DGLStreamHandle stream);
...@@ -337,8 +328,8 @@ class NDArray {
  static void DefaultDeleter(NDArray::Container* ptr);
  // Local create function which allocates tensor metadata
  // but does not allocate space for the data.
  static NDArray Create(
      std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx);
  // Implementation of API function
  static DGLArray* MoveAsDGLArray(NDArray arr);
};
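A hedged usage sketch of the creation and copy API declared above (CPU context, int64 data, names chosen for illustration):

// Illustration only: vector -> NDArray -> copy -> vector round trip.
std::vector<int64_t> src = {1, 2, 3, 4};
dgl::runtime::NDArray a = dgl::runtime::NDArray::FromVector(src);
dgl::runtime::NDArray b = dgl::runtime::NDArray::Empty({4}, a->dtype, a->ctx);
a.CopyTo(b);                                       // same-size copy
std::vector<int64_t> out = b.ToVector<int64_t>();  // no type casting performed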
...@@ -360,7 +351,6 @@ class NDArray {
 */
inline bool SaveDGLArray(dmlc::Stream* strm, const DGLArray* tensor);

/**
 * @brief Reference counted Container object used to back NDArray.
 *
...@@ -409,9 +399,7 @@ struct NDArray::Container {
  /** @brief pointer to shared memory */
  std::shared_ptr<SharedMemory> mem;
  /** @brief developer function, increases reference counter */
  void IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); }
  /** @brief developer function, decrease reference counter */
  void DecRef() {
    if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
...@@ -444,16 +432,12 @@ struct NDArray::Container {
// implementations of inline functions
// the usages of functions are documented in place.
inline NDArray::NDArray(Container* data) : data_(data) {
  if (data_) data_->IncRef();
}

inline NDArray::NDArray(const NDArray& other) : data_(other.data_) {
  if (data_) data_->IncRef();
}

inline void NDArray::reset() {
...@@ -473,21 +457,22 @@ inline void NDArray::CopyFrom(const NDArray& other) {
  CopyFrom(&(other.data_->dl_tensor));
}

inline void NDArray::CopyTo(DGLArray* other) const {
  CHECK(data_ != nullptr);
  CopyFromTo(&(data_->dl_tensor), other);
}

inline void NDArray::CopyTo(const NDArray& other) const {
  CHECK(other.data_ != nullptr);
  CopyTo(&(other.data_->dl_tensor));
}

inline NDArray NDArray::CopyTo(const DGLContext& ctx) const {
  CHECK(data_ != nullptr);
  const DGLArray* dptr = operator->();
  NDArray ret = Empty(
      std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype,
      ctx);
  this->CopyTo(ret);
  return ret;
}
...@@ -530,8 +515,7 @@ inline const DGLArray* NDArray::operator->() const {
/** @brief Magic number for NDArray file */
constexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F;

inline bool SaveDGLArray(dmlc::Stream* strm, DGLArray* tensor) {
  uint64_t header = kDGLNDArrayMagic, reserved = 0;
  strm->Write(header);
  strm->Write(reserved);
...@@ -560,16 +544,14 @@ inline bool SaveDGLArray(dmlc::Stream* strm,
  int64_t data_byte_size = type_bytes * num_elems;
  strm->Write(data_byte_size);

  if (DMLC_IO_NO_ENDIAN_SWAP && tensor->ctx.device_type == kDGLCPU &&
      tensor->strides == nullptr && tensor->byte_offset == 0) {
    // quick path
    strm->Write(tensor->data, data_byte_size);
  } else {
    std::vector<uint8_t> bytes(data_byte_size);
    CHECK_EQ(
        DGLArrayCopyToBytes(tensor, dmlc::BeginPtr(bytes), data_byte_size), 0)
        << DGLGetLastError();
    if (!DMLC_IO_NO_ENDIAN_SWAP) {
      dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
...@@ -586,22 +568,37 @@ inline bool SaveDGLArray(dmlc::Stream* strm,
 */
inline const char* TypeCode2Str(int type_code) {
  switch (type_code) {
    case kDGLInt:
      return "int";
    case kDGLUInt:
      return "uint";
    case kDGLFloat:
      return "float";
    case kStr:
      return "str";
    case kBytes:
      return "bytes";
    case kHandle:
      return "handle";
    case kNull:
      return "NULL";
    case kObjectHandle:
      return "ObjectHandle";
    case kArrayHandle:
      return "ArrayHandle";
    case kDGLDataType:
      return "DGLDataType";
    case kDGLContext:
      return "DGLContext";
    case kFuncHandle:
      return "FunctionHandle";
    case kModuleHandle:
      return "ModuleHandle";
    case kNDArrayContainer:
      return "NDArrayContainer";
    default:
      LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
      return "";
  }
}
...@@ -612,10 +609,14 @@ inline const char* TypeCode2Str(int type_code) {
 */
inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
  switch (device_type) {
    case kDGLCPU:
      return "cpu";
    case kDGLCUDA:
      return "cuda";
    default:
      LOG(FATAL) << "Unsupported device type code="
                 << static_cast<int>(device_type);
      return "";
  }
}
...@@ -626,14 +627,18 @@ inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
 */
inline DGLDataType String2DGLDataType(std::string s) {
  DGLDataType t;
  t.bits = 32;
  t.lanes = 1;
  const char* scan;
  if (s.substr(0, 3) == "int") {
    t.code = kDGLInt;
    scan = s.c_str() + 3;
  } else if (s.substr(0, 4) == "uint") {
    t.code = kDGLUInt;
    scan = s.c_str() + 4;
  } else if (s.substr(0, 5) == "float") {
    t.code = kDGLFloat;
    scan = s.c_str() + 5;
  } else if (s.substr(0, 6) == "handle") {
    t.code = kHandle;
    t.bits = 64;  // handle uses 64 bit by default.
...@@ -674,9 +679,9 @@ inline std::string DGLDataType2String(DGLDataType t) {
}
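A short sketch of the string conversions above, assuming the usual round trip:

// Illustration only: dtype <-> string round trip.
DGLDataType t = dgl::runtime::String2DGLDataType("int32");
// expected: t.code == kDGLInt, t.bits == 32, t.lanes == 1
std::string s = dgl::runtime::DGLDataType2String(t);  // expected "int32"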
// macro to check type code.
#define DGL_CHECK_TYPE_CODE(CODE, T)                                  \
  CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but got " \
                    << TypeCode2Str(CODE)

}  // namespace runtime
}  // namespace dgl
...@@ -686,69 +691,69 @@ DMLC_DECLARE_TRAITS(has_saveload, dgl::runtime::NDArray, true);
}  // namespace dmlc
///////////////// Operator overloading for NDArray /////////////////
dgl::runtime::NDArray operator+(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator-(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator*(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator/(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator%(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator+(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator-(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator*(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator/(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator%(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator+(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator-(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator*(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator/(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator%(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator-(const dgl::runtime::NDArray& array);

dgl::runtime::NDArray operator>(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator<(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator>=(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator<=(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator==(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator!=(
    const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator>(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator<(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator>=(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator<=(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator==(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator!=(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator>(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator<(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator>=(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator<=(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator==(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator!=(int64_t lhs, const dgl::runtime::NDArray& a2);

std::ostream& operator<<(std::ostream& os, dgl::runtime::NDArray array);
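A hedged sketch of the overloads declared above; the elementwise, scalar-broadcasting semantics are assumed from the signatures rather than stated by this header:

// Illustration only: scalar broadcast, comparison, and printing.
dgl::runtime::NDArray x =
    dgl::runtime::NDArray::FromVector(std::vector<int64_t>{1, 2, 3});
dgl::runtime::NDArray y = x * 2 + 1;  // expected {3, 5, 7}
dgl::runtime::NDArray m = y > 4;      // expected elementwise 0/1 result
std::cout << m;                       // printed via the operator<< above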
///////////////// Operator overloading for DGLDataType /////////////////

/** @brief Check whether two data types are the same.*/
inline bool operator==(const DGLDataType& ty1, const DGLDataType& ty2) {
  return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
}

/** @brief Check whether two data types are different.*/
inline bool operator!=(const DGLDataType& ty1, const DGLDataType& ty2) {
  return !(ty1 == ty2);
}

#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, DGLDataType t) {
  os << dgl::runtime::TypeCode2Str(t.code);
  if (t.code == kHandle) return os;
  os << static_cast<int>(t.bits);
...@@ -762,18 +767,20 @@ inline std::ostream& operator << (std::ostream& os, DGLDataType t) {
///////////////// Operator overloading for DGLContext /////////////////

/** @brief Check whether two device contexts are the same.*/
inline bool operator==(const DGLContext& ctx1, const DGLContext& ctx2) {
  return ctx1.device_type == ctx2.device_type &&
         ctx1.device_id == ctx2.device_id;
}

/** @brief Check whether two device contexts are different.*/
inline bool operator!=(const DGLContext& ctx1, const DGLContext& ctx2) {
  return !(ctx1 == ctx2);
}

#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, const DGLContext& ctx) {
  return os << dgl::runtime::DeviceTypeCode2Str(ctx.device_type) << ":"
            << ctx.device_id;
}
#endif
...
...@@ -7,19 +7,18 @@
#define DGL_RUNTIME_PARALLEL_FOR_H_

#include <dmlc/omp.h>

#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <exception>
#include <string>
#include <utility>
#include <vector>
namespace {
int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; }
}  // namespace

namespace dgl {
namespace runtime {
...@@ -37,9 +36,7 @@ struct DefaultGrainSizeT {
    }
  }

  size_t operator()() { return grain_size; }
};
}  // namespace
...@@ -48,7 +45,9 @@ inline size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) {
  if (omp_in_parallel() || end - begin <= grain_size || end - begin == 1)
    return 1;
  return std::min(
      static_cast<int64_t>(omp_get_max_threads()),
      divup(end - begin, grain_size));
#else
  return 1;
#endif
...@@ -60,15 +59,13 @@ static DefaultGrainSizeT default_grain_size;
 * @brief OpenMP-based parallel for loop.
 *
 * It requires each thread's workload to have at least \a grain_size elements.
 * The loop body will be a function that takes in two arguments \a begin and
 * \a end, which stand for the starting (inclusive) and ending (exclusive)
 * index of the workload.
 */
template <typename F>
void parallel_for(
    const size_t begin, const size_t end, const size_t grain_size, F&& f) {
  if (begin >= end) {
    return;
  }
...@@ -89,13 +86,11 @@ void parallel_for(
      try {
        f(begin_tid, end_tid);
      } catch (...) {
        if (!err_flag.test_and_set()) eptr = std::current_exception();
      }
    }
  }
  if (eptr) std::rethrow_exception(eptr);
#else
  f(begin, end);
#endif
...@@ -110,22 +105,21 @@ void parallel_for(
 * parallel for pragma with static scheduling.
 */
template <typename F>
void parallel_for(const size_t begin, const size_t end, F&& f) {
  parallel_for(begin, end, default_grain_size(), std::forward<F>(f));
}
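A usage sketch of the default-grain-size overload above; the body receives a half-open [begin, end) chunk per thread:

// Illustration only: square a vector in parallel chunks.
std::vector<float> data(1 << 20, 2.f);
dgl::runtime::parallel_for(0, data.size(), [&](size_t b, size_t e) {
  for (size_t i = b; i < e; ++i) data[i] *= data[i];
});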
/**
 * @brief OpenMP-based two-stage parallel reduction.
 *
 * The first-stage reduction function \a f works in parallel. Each thread's
 * workload has at least \a grain_size elements. The loop body will be a
 * function that takes in the starting index (inclusive), the ending index
 * (exclusive), and the reduction identity.
 *
 * The second-stage reduction function \a sf is a binary function working in
 * the main thread. It aggregates the partially reduced result computed from
 * each thread.
 *
 * Example to compute a parallelized max reduction of an array \c a:
 *
...@@ -134,11 +128,9 @@ void parallel_for(
 *       100,  // ending index
 *       1,    // grain size
 *       -std::numeric_limits<float>::infinity,   // identity
 *       [&a](int begin, int end, float ident) {  // first-stage partial reducer
 *         float result = ident;
 *         for (int i = begin; i < end; ++i)
 *           result = std::max(result, a[i]);
 *         return result;
 *       },
 *       [](float result, float partial_result) {
 *         return std::max(result, partial_result);
 *       });
 */
template <typename DType, typename F, typename SF>
DType parallel_reduce(
    const size_t begin, const size_t end, const size_t grain_size,
    const DType ident, const F& f, const SF& sf) {
  if (begin >= end) {
    return ident;
  }
...@@ -174,17 +162,14 @@ DType parallel_reduce(
      try {
        results[tid] = f(begin_tid, end_tid, ident);
      } catch (...) {
        if (!err_flag.test_and_set()) eptr = std::current_exception();
      }
    }
  }
  if (eptr) std::rethrow_exception(eptr);
  DType out = ident;
  for (int64_t i = 0; i < num_threads; ++i) out = sf(out, results[i]);
  return out;
}
...
...@@ -7,12 +7,14 @@
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/shared_mem.h>

#include <sstream>

#include "../c_api_common.h"
#include "./arith.h"
#include "./array_op.h"

using namespace dgl::runtime;
...@@ -73,17 +75,14 @@ template NDArray Full<double>(double val, int64_t length, DGLContext ctx);
IdArray AsNumBits(IdArray arr, uint8_t bits) {
  CHECK(bits == 32 || bits == 64)
      << "Invalid ID type. Must be int32 or int64, but got int"
      << static_cast<int>(bits) << ".";
  if (arr->dtype.bits == bits) return arr;
  if (arr.NumElements() == 0) return NewIdArray(arr->shape[0], arr->ctx, bits);
  IdArray ret;
  ATEN_XPU_SWITCH_CUDA(arr->ctx.device_type, XPU, "AsNumBits", {
    ATEN_ID_TYPE_SWITCH(
        arr->dtype, IdType, { ret = impl::AsNumBits<XPU, IdType>(arr, bits); });
  });
  return ret;
}
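A hedged sketch of AsNumBits, assuming the dgl::aten namespace this file lives in:

// Illustration only: widen 32-bit ids to 64-bit; a no-op when bits match.
IdArray ids32 = NewIdArray(10, DGLContext{kDGLCPU, 0}, 32);
IdArray ids64 = AsNumBits(ids32, 64);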
...@@ -98,14 +97,12 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
  ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
    const int64_t len = lhs->shape[0];
    ret = NewIdArray(2 * len, lhs->ctx, lhs->dtype.bits);
    device->CopyDataFromTo(
        lhs.Ptr<IdType>(), 0, ret.Ptr<IdType>(), 0, len * sizeof(IdType), ctx,
        ctx, lhs->dtype);
    device->CopyDataFromTo(
        rhs.Ptr<IdType>(), 0, ret.Ptr<IdType>(), len * sizeof(IdType),
        len * sizeof(IdType), ctx, ctx, lhs->dtype);
  });
  return ret;
}
...@@ -127,11 +124,11 @@ NDArray IndexSelect(NDArray array, IdArray index) {
  return ret;
}

template <typename ValueType>
ValueType IndexSelect(NDArray array, int64_t index) {
  CHECK_EQ(array->ndim, 1) << "Only supports selecting values from a 1D array.";
  CHECK(index >= 0 && index < array.NumElements())
      << "Index " << index << " is out of bound.";
  ValueType ret = 0;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
...@@ -150,17 +147,17 @@ template double IndexSelect<double>(NDArray array, int64_t index);
NDArray IndexSelect(NDArray array, int64_t start, int64_t end) {
  CHECK_EQ(array->ndim, 1) << "Only supports selecting values from a 1D array.";
  CHECK(start >= 0 && start < array.NumElements())
      << "Index " << start << " is out of bound.";
  CHECK(end >= 0 && end <= array.NumElements())
      << "Index " << end << " is out of bound.";
  CHECK_LE(start, end);
  auto device = runtime::DeviceAPI::Get(array->ctx);
  const int64_t len = end - start;
  NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
  ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
    device->CopyDataFromTo(
        array->data, start * sizeof(DType), ret->data, 0, len * sizeof(DType),
        array->ctx, ret->ctx, array->dtype);
  });
  return ret;
}
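A hedged sketch of the two IndexSelect overloads just defined:

// Illustration only: scalar and range selection from a 1D array.
NDArray a = NDArray::FromVector(std::vector<int64_t>{0, 1, 2, 3, 4, 5});
int64_t v = IndexSelect<int64_t>(a, 3);  // expected 3
NDArray s = IndexSelect(a, 2, 5);        // expected values {2, 3, 4}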
...@@ -182,8 +179,7 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
  CHECK_SAME_CONTEXT(index, value);
  CHECK_SAME_CONTEXT(index, out);
  CHECK_EQ(value->shape[0], index->shape[0]);
  if (index->shape[0] == 0) return;
  ATEN_XPU_SWITCH_CUDA(value->ctx.device_type, XPU, "Scatter_", {
    ATEN_DTYPE_SWITCH(value->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
...@@ -225,31 +221,25 @@ NDArray Concat(const std::vector<IdArray>& arrays) {
    CHECK_SAME_CONTEXT(arrays[0], arrays[i]);
  }

  NDArray ret_arr = NDArray::Empty({len}, arrays[0]->dtype, arrays[0]->ctx);

  auto device = runtime::DeviceAPI::Get(arrays[0]->ctx);
  for (size_t i = 0; i < arrays.size(); ++i) {
    ATEN_DTYPE_SWITCH(arrays[i]->dtype, DType, "array", {
      device->CopyDataFromTo(
          static_cast<DType*>(arrays[i]->data), 0,
          static_cast<DType*>(ret_arr->data), offset,
          arrays[i]->shape[0] * sizeof(DType), arrays[i]->ctx, ret_arr->ctx,
          arrays[i]->dtype);

      offset += arrays[i]->shape[0] * sizeof(DType);
    });
  }

  return ret_arr;
}
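A hedged sketch of Concat on two id arrays:

// Illustration only: concatenation preserves the order of the inputs.
IdArray x = VecToIdArray(std::vector<int64_t>{0, 1});
IdArray y = VecToIdArray(std::vector<int64_t>{2, 3, 4});
NDArray xy = Concat({x, y});  // expected ids {0, 1, 2, 3, 4}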
template <typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
  std::tuple<NDArray, IdArray, IdArray> ret;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Pack", {
...@@ -262,8 +252,10 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
template std::tuple<NDArray, IdArray, IdArray> Pack<int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint32_t>(
    NDArray, uint32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint64_t>(
    NDArray, uint64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<double>(NDArray, double);
...@@ -292,9 +284,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
IdArray NonZero(NDArray array) {
  IdArray ret;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "NonZero", {
    ATEN_ID_TYPE_SWITCH(
        array->dtype, DType, { ret = impl::NonZero<XPU, DType>(array); });
  });
  return ret;
}
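A hedged sketch of NonZero, assuming it returns the positions of nonzero entries:

// Illustration only.
IdArray z = VecToIdArray(std::vector<int64_t>{0, 7, 0, 3});
IdArray nz = NonZero(z);  // expected indices {1, 3}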
...@@ -322,8 +313,7 @@ std::string ToDebugString(NDArray array) {
      oss << a.Ptr<DType>()[i] << ", ";
    }
  });
  if (a.NumElements() > 10) oss << "...";
  oss << "], dtype=" << array->dtype << ", ctx=" << array->ctx << ")";
  return oss.str();
}
...@@ -396,8 +386,7 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
}

bool CSRIsSorted(CSRMatrix csr) {
  if (csr.indices->shape[0] <= 1) return true;
  bool ret = false;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsSorted", {
    ret = impl::CSRIsSorted<XPU, IdType>(csr);
...@@ -417,14 +406,16 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
}

template <typename DType>
NDArray CSRGetData(
    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {
  NDArray ret;
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  CHECK_SAME_CONTEXT(rows, weights);
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetData", {
    ret =
        impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols, weights, filler);
  });
  return ret;
}
...@@ -459,11 +450,12 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
  COOMatrix ret;
  if (data_as_order) {
    ATEN_XPU_SWITCH_CUDA(
        csr.indptr->ctx.device_type, XPU, "CSRToCOODataAsOrder", {
          ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
            ret = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);
          });
        });
  } else {
    ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRToCOO", {
      ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
...@@ -506,17 +498,16 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
}

void CSRSort_(CSRMatrix* csr) {
  if (csr->sorted) return;
  ATEN_CSR_SWITCH_CUDA(
      *csr, XPU, IdType, "CSRSort_", { impl::CSRSort_<XPU, IdType>(csr); });
}
std::pair<CSRMatrix, NDArray> CSRSortByTag(
    const CSRMatrix& csr, IdArray tag, int64_t num_tags) {
  CHECK_EQ(csr.indices->shape[0], tag->shape[0])
      << "The length of the tag array should be equal to the number of "
         "non-zero data.";
  CHECK_SAME_CONTEXT(csr.indices, tag);
  CHECK_INT(tag, "tag");
  std::pair<CSRMatrix, NDArray> ret;
...@@ -528,7 +519,8 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
  return ret;
}
CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
  CSRMatrix ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRReorder", {
    ret = impl::CSRReorder<XPU, IdType>(csr, new_row_ids, new_col_ids);
...@@ -545,23 +537,26 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
}

COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  COOMatrix ret;
  if (IsNullArray(prob_or_mask)) {
    ATEN_CSR_SWITCH_CUDA_UVA(
        mat, rows, XPU, IdType, "CSRRowWiseSamplingUniform", {
          ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(
              mat, rows, num_samples, replace);
        });
  } else {
    // prob_or_mask is pinned and rows on GPU is valid
    CHECK_VALID_CONTEXT(prob_or_mask, rows);
    ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
      CHECK(!(prob_or_mask->dtype.bits == 8 && XPU == kDGLCUDA))
          << "GPU sampling with masks is not supported yet.";
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask->dtype, FloatType, "probability or mask", {
            ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
                mat, rows, num_samples, prob_or_mask, replace);
          });
    });
  }
  return ret;
...@@ -569,26 +564,28 @@ COOMatrix CSRRowWiseSampling(
COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  COOMatrix ret;
  CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", {
    if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {
      ret = impl::CSRRowWisePerEtypeSamplingUniform<XPU, IdType>(
          mat, rows, eid2etype_offset, num_samples, replace,
          rowwise_etype_sorted);
    } else {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask[0]->dtype, DType, "probability or mask", {
            ret = impl::CSRRowWisePerEtypeSampling<XPU, IdType, DType>(
                mat, rows, eid2etype_offset, num_samples, prob_or_mask,
                replace, rowwise_etype_sorted);
          });
    }
  });
  return ret;
}
COOMatrix CSRRowWiseTopk(
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
  COOMatrix ret;
...@@ -602,16 +599,12 @@ COOMatrix CSRRowWiseTopk(
}

COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace) {
  COOMatrix ret;
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSamplingBiased", {
    ATEN_FLOAT_TYPE_SWITCH(bias->dtype, FloatType, "bias", {
      ret = impl::CSRRowWiseSamplingBiased<XPU, IdType, FloatType>(
          mat, rows, num_samples, tag_offset, bias, replace);
    });
  });
...@@ -619,12 +612,8 @@ COOMatrix CSRRowWiseSamplingBiased(
}

std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
  CHECK_GT(num_samples, 0) << "Number of samples must be positive";
  CHECK_GT(num_trials, 0) << "Number of sampling trials must be positive";
  std::pair<IdArray, IdArray> result;
...@@ -635,16 +624,16 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
  return result;
}
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
  CSRMatrix ret;
  CHECK_GT(csrs.size(), 1)
      << "UnionCsr creates a union of multiple CSRMatrixes";
  // sanity check
  for (size_t i = 1; i < csrs.size(); ++i) {
    CHECK_EQ(csrs[0].num_rows, csrs[i].num_rows)
        << "UnionCsr requires all CSRMatrixes to have the same number of rows";
    CHECK_EQ(csrs[0].num_cols, csrs[i].num_cols)
        << "UnionCsr requires all CSRMatrixes to have the same number of cols";
    CHECK_SAME_CONTEXT(csrs[0].indptr, csrs[i].indptr);
    CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);
  }
@@ -655,9 +644,7 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
  return ret;
}
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr) {
  std::tuple<CSRMatrix, IdArray, IdArray> ret;
  CSRMatrix sorted_csr = (CSRIsSorted(csr)) ? csr : CSRSort(csr);
@@ -709,7 +696,8 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {
  return ret;
}
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row) {
  std::pair<NDArray, NDArray> ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetRowDataAndIndices", {
    ret = impl::COOGetRowDataAndIndices<XPU, IdType>(coo, row);
@@ -741,9 +729,8 @@ COOMatrix COOTranspose(COOMatrix coo) {
CSRMatrix COOToCSR(COOMatrix coo) {
  CSRMatrix ret;
  ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "COOToCSR", {
    ATEN_ID_TYPE_SWITCH(
        coo.row->dtype, IdType, { ret = impl::COOToCSR<XPU, IdType>(coo); });
  });
  return ret;
}
@@ -773,8 +760,7 @@ COOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) {
}

void COOSort_(COOMatrix* mat, bool sort_column) {
  if ((mat->row_sorted && !sort_column) || mat->col_sorted) return;
  ATEN_XPU_SWITCH_CUDA(mat->row->ctx.device_type, XPU, "COOSort_", {
    ATEN_ID_TYPE_SWITCH(mat->row->dtype, IdType, {
      impl::COOSort_<XPU, IdType>(mat, sort_column);
@@ -783,8 +769,7 @@ void COOSort_(COOMatrix* mat, bool sort_column) {
}
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
  if (coo.row->shape[0] <= 1) return {true, true};
  std::pair<bool, bool> ret;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOIsSorted", {
    ret = impl::COOIsSorted<XPU, IdType>(coo);
@@ -792,7 +777,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
  return ret;
}
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOReorder", {
    ret = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);
@@ -809,17 +795,19 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) {
}
COOMatrix COORowWiseSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  COOMatrix ret;
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseSampling", {
    if (IsNullArray(prob_or_mask)) {
      ret = impl::COORowWiseSamplingUniform<XPU, IdType>(
          mat, rows, num_samples, replace);
    } else {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask->dtype, DType, "probability or mask", {
            ret = impl::COORowWiseSampling<XPU, IdType, DType>(
                mat, rows, num_samples, prob_or_mask, replace);
          });
    }
  });
  return ret;
@@ -827,20 +815,21 @@ COOMatrix COORowWiseSampling(
COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace) {
  COOMatrix ret;
  CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", {
    if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {
      ret = impl::COORowWisePerEtypeSamplingUniform<XPU, IdType>(
          mat, rows, eid2etype_offset, num_samples, replace);
    } else {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask[0]->dtype, DType, "probability or mask", {
            ret = impl::COORowWisePerEtypeSampling<XPU, IdType, DType>(
                mat, rows, eid2etype_offset, num_samples, prob_or_mask,
                replace);
          });
    }
  });
  return ret;
@@ -876,7 +865,7 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
  return ret;
}

COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking) {
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOLineGraph", {
    ret = impl::COOLineGraph<XPU, IdType>(coo, backtracking);
@@ -886,13 +875,14 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
  COOMatrix ret;
  CHECK_GT(coos.size(), 1)
      << "UnionCoo creates a union of multiple COOMatrixes";
  // sanity check
  for (size_t i = 1; i < coos.size(); ++i) {
    CHECK_EQ(coos[0].num_rows, coos[i].num_rows)
        << "UnionCoo requires all COOMatrixes to have the same number of rows";
    CHECK_EQ(coos[0].num_cols, coos[i].num_cols)
        << "UnionCoo requires all COOMatrixes to have the same number of cols";
    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);
    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);
  }
@@ -914,20 +904,19 @@ COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
  if (has_data) {
    std::vector<IdArray> eid_data;
    eid_data.push_back(
        COOHasData(coos[0]) ? coos[0].data
                            : Range(
                                  0, coos[0].row->shape[0],
                                  coos[0].row->dtype.bits, coos[0].row->ctx));
    int64_t num_edges = coos[0].row->shape[0];
    for (size_t i = 1; i < coos.size(); ++i) {
      eid_data.push_back(
          COOHasData(coos[i])
              ? coos[i].data + num_edges
              : Range(
                    num_edges, num_edges + coos[i].row->shape[0],
                    coos[i].row->dtype.bits, coos[i].row->ctx));
      num_edges += coos[i].row->shape[0];
    }
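    // Illustrative example (not in the original source): unioning two COOs
    // with 3 and 2 edges and no data arrays yields
    // eid_data = {[0, 1, 2], [3, 4]}; each graph's edge IDs are shifted by
    // the total edge count of the graphs before it.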
@@ -935,71 +924,62 @@ COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
  }
  return COOMatrix(
      coos[0].num_rows, coos[0].num_cols, row, col, data, false, false);
}
std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo) {
  // coo column sorted
  const COOMatrix sorted_coo = COOSort(coo, true);
  const IdArray eids_shuffled =
      COOHasData(sorted_coo)
          ? sorted_coo.data
          : Range(
                0, sorted_coo.row->shape[0], sorted_coo.row->dtype.bits,
                sorted_coo.row->ctx);
  const auto& coalesced_result = COOCoalesce(sorted_coo);
  const COOMatrix& coalesced_adj = coalesced_result.first;
  const IdArray& count = coalesced_result.second;
  /**
   * eids_shuffled actually already contains the mapping from the old edge
   * space to the new one:
   *
   * * eids_shuffled[0:count[0]] indicates the original edge IDs that coalesced
   *   into new edge #0.
   * * eids_shuffled[count[0]:count[0] + count[1]] indicates those that
   *   coalesced into new edge #1.
   * * eids_shuffled[count[0] + count[1]:count[0] + count[1] + count[2]]
   *   indicates those that coalesced into new edge #2.
   * * etc.
   *
   * Here, we need to translate eids_shuffled to an array "eids_remapped" such
   * that eids_remapped[i] indicates the new edge ID the old edge #i is mapped
   * to. The translation can simply be achieved by (in numpy code):
   *
   *   new_eid_for_eids_shuffled = np.arange(len(count)).repeat(count)
   *   eids_remapped = np.zeros_like(new_eid_for_eids_shuffled)
   *   eids_remapped[eids_shuffled] = new_eid_for_eids_shuffled
   */
  const IdArray new_eids = Range(
      0, coalesced_adj.row->shape[0], coalesced_adj.row->dtype.bits,
      coalesced_adj.row->ctx);
  const IdArray eids_remapped = Scatter(Repeat(new_eids, count), eids_shuffled);

  COOMatrix ret = COOMatrix(
      coalesced_adj.num_rows, coalesced_adj.num_cols, coalesced_adj.row,
      coalesced_adj.col, NullArray(), true, true);
  return std::make_tuple(ret, count, eids_remapped);
}
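To make the translation above concrete, here is a minimal standalone sketch (not part of this commit) that replays the same Repeat/Scatter remapping on plain std::vector values; the toy count and eids_shuffled values are invented for illustration, and DGL's actual Repeat/Scatter operate on IdArray:

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of eids_remapped = Scatter(Repeat(new_eids, count), eids_shuffled).
// Suppose coalescing merged 5 original edges into 3 new ones, with
// count = {2, 1, 2} and eids_shuffled = {4, 0, 3, 1, 2}.
int main() {
  const std::vector<int64_t> count = {2, 1, 2};
  const std::vector<int64_t> eids_shuffled = {4, 0, 3, 1, 2};

  // Repeat: new edge ID i repeated count[i] times -> {0, 0, 1, 2, 2}.
  std::vector<int64_t> repeated;
  for (size_t i = 0; i < count.size(); ++i)
    repeated.insert(repeated.end(), count[i], static_cast<int64_t>(i));

  // Scatter: eids_remapped[eids_shuffled[j]] = repeated[j].
  std::vector<int64_t> eids_remapped(repeated.size());
  for (size_t j = 0; j < repeated.size(); ++j)
    eids_remapped[eids_shuffled[j]] = repeated[j];

  // Old edges 4 and 0 coalesced into new edge 0, old edge 3 into new edge 1,
  // and old edges 1 and 2 into new edge 2.
  assert((eids_remapped == std::vector<int64_t>{0, 2, 2, 1, 0}));
  return 0;
}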
///////////////////////// Graph Traverse routines //////////////////////////
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should have the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      ret = impl::BFSNodesFrontiers<XPU, IdType>(csr, source);
@@ -1010,12 +990,12 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should have the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSEdgesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      ret = impl::BFSEdgesFrontiers<XPU, IdType>(csr, source);
@@ -1026,24 +1006,25 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  Frontiers ret;
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(
      csr.indptr->ctx.device_type, XPU, "TopologicalNodesFrontiers", {
        ATEN_ID_TYPE_SWITCH(csr.indices->dtype, IdType, {
          ret = impl::TopologicalNodesFrontiers<XPU, IdType>(csr);
        });
      });
  return ret;
}
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should have the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      ret = impl::DGLDFSEdges<XPU, IdType>(csr, source);
@@ -1052,104 +1033,97 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  return ret;
}
Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels) {
  Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should have the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSLabeledEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      ret = impl::DGLDFSLabeledEdges<XPU, IdType>(
          csr, source, has_reverse_edge, has_nontree_edge, return_labels);
    });
  });
  return ret;
}
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0]; SparseMatrixRef spmat = args[0];
*rv = spmat->format; *rv = spmat->format;
}); });
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumRows") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumRows")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0]; SparseMatrixRef spmat = args[0];
*rv = spmat->num_rows; *rv = spmat->num_rows;
}); });
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumCols") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumCols")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0]; SparseMatrixRef spmat = args[0];
*rv = spmat->num_cols; *rv = spmat->num_cols;
}); });
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetIndices") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetIndices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0]; SparseMatrixRef spmat = args[0];
const int64_t i = args[1]; const int64_t i = args[1];
*rv = spmat->indices[i]; *rv = spmat->indices[i];
}); });
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFlags") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFlags")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0]; SparseMatrixRef spmat = args[0];
List<Value> flags; List<Value> flags;
for (bool flg : spmat->flags) { for (bool flg : spmat->flags) {
flags.push_back(Value(MakeValue(flg))); flags.push_back(Value(MakeValue(flg)));
} }
*rv = flags; *rv = flags;
}); });
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLCreateSparseMatrix") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLCreateSparseMatrix")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t format = args[0]; const int32_t format = args[0];
const int64_t nrows = args[1]; const int64_t nrows = args[1];
const int64_t ncols = args[2]; const int64_t ncols = args[2];
const List<Value> indices = args[3]; const List<Value> indices = args[3];
const List<Value> flags = args[4]; const List<Value> flags = args[4];
std::shared_ptr<SparseMatrix> spmat(new SparseMatrix( std::shared_ptr<SparseMatrix> spmat(new SparseMatrix(
format, nrows, ncols, format, nrows, ncols, ListValueToVector<IdArray>(indices),
ListValueToVector<IdArray>(indices),
ListValueToVector<bool>(flags))); ListValueToVector<bool>(flags)));
*rv = SparseMatrixRef(spmat); *rv = SparseMatrixRef(spmat);
}); });
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const std::string name = args[0]; const std::string name = args[0];
#ifndef _WIN32 #ifndef _WIN32
*rv = SharedMemory::Exist(name); *rv = SharedMemory::Exist(name);
#else #else
*rv = false; *rv = false;
#endif // _WIN32 #endif // _WIN32
}); });
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray array = args[0]; NDArray array = args[0];
CHECK_EQ(array->dtype.code, kDGLUInt); CHECK_EQ(array->dtype.code, kDGLUInt);
std::vector<int64_t> shape(array->shape, array->shape + array->ndim); std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
DGLDataType dtype = array->dtype; DGLDataType dtype = array->dtype;
dtype.code = kDGLInt; dtype.code = kDGLInt;
*rv = array.CreateView(shape, dtype, 0); *rv = array.CreateView(shape, dtype, 0);
}); });
}  // namespace aten
}  // namespace dgl

std::ostream& operator<<(std::ostream& os, dgl::runtime::NDArray array) {
  return os << dgl::aten::ToDebugString(array);
}
@@ -8,9 +8,10 @@
#include <dgl/array.h>
#include <dgl/graph_traversal.h>

#include <tuple>
#include <utility>
#include <vector>

namespace dgl {
namespace aten {
@@ -82,7 +83,8 @@ template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);

template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRIsNonZero(
    CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);

template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr);
@@ -104,19 +106,21 @@ bool CSRIsSorted(CSRMatrix csr);
template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    bool return_eids, runtime::NDArray weights, DType filler);

template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    runtime::NDArray weights, DType filler) {
  return CSRGetData<XPU, IdType, DType>(
      csr, rows, cols, false, weights, filler);
}

template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
  return CSRGetData<XPU, IdType, IdType>(
      csr, rows, cols, true, NullArray(rows->dtype), -1);
}

template <DGLDeviceType XPU, typename IdType>
@@ -141,20 +145,23 @@ template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);

template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr);

template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag(
    const CSRMatrix& csr, IdArray tag_array, int64_t num_tags);

template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);

template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
@@ -162,15 +169,16 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);

// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted);

template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
@@ -178,9 +186,9 @@ COOMatrix CSRRowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace,
    bool rowwise_etype_sorted);

// FloatType is the type of weight data.
template <DGLDeviceType XPU, typename IdType, typename DType>
@@ -189,21 +197,13 @@ COOMatrix CSRRowWiseTopk(
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace);

template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy);

// Union CSRMatrixes
template <DGLDeviceType XPU, typename IdType>
@@ -218,7 +218,8 @@ template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);

template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOIsNonZero(
    COOMatrix coo, runtime::NDArray row, runtime::NDArray col);

template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo);
@@ -230,15 +231,16 @@ template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);

template <DGLDeviceType XPU, typename IdType>
std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row);

template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> COOGetDataAndIndices(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(
    COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo);
@@ -253,7 +255,8 @@ template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
@@ -273,13 +276,13 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWiseSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace);

// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace);
@@ -289,8 +292,7 @@ COOMatrix COORowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace);

// FloatType is the type of weight data.
@@ -313,14 +315,12 @@ template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);

template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels);

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);

}  // namespace impl
}  // namespace aten
...
/**
 * Copyright (c) 2020 by Contributors
 * @file kernel/cpu/gather_mm.cc
 * @brief GatherMM C APIs and definitions.
 */
#include "./gather_mm.h"

#include <dgl/array.h>

namespace dgl {
@@ -11,81 +12,72 @@ namespace aten {
/** @brief Generalized SegmentMM. */
template <int XPU, typename IdType, typename DType>
void SegmentMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
}

template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
}

/** @brief Generalized GatherMM. */
template <int XPU, typename IdType, typename DType>
void GatherMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}

/** @brief Generalized GatherMM_scatter. */
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c) {
  LOG(FATAL) << "Unsupported CPU kernel for GatherMMScatter.";
}
template void GatherMM<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b);
template void GatherMMScatter<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, float>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, double>(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
...
@@ -7,13 +7,14 @@
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_

#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>

namespace dgl {
namespace aten {
@@ -41,10 +42,10 @@ namespace impl {
template <typename IdxType>
using PickFn = std::function<void(
    IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
    const IdxType* col, const IdxType* data, IdxType* out_idx)>;
// User-defined function for determining the number of elements to pick from
// one row.
//
// The column indices of the given row are stored in
//   [col + off, col + off + len)
@@ -63,8 +64,8 @@ using PickFn = std::function<void(
// \param data Pointer of the data indices.
template <typename IdxType>
using NumPicksFn = std::function<IdxType(
    IdxType rowid, IdxType off, IdxType len, const IdxType* col,
    const IdxType* data)>;
// User-defined function for picking elements from a range within a row.
//
@@ -73,7 +74,8 @@ using NumPicksFn = std::function<IdxType(
//
// Similarly, the data indices are stored in
//   data[off + et_idx[et_offset + i]]
// The data index pointer could be NULL, which means
//   data[i] == off + et_idx[et_offset + i]
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
@@ -93,15 +95,17 @@ using EtypeRangePickFn = std::function<void(
    const IdxType* eid, IdxType* out_idx)>;
// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation
// independently.
template <typename IdxType>
COOMatrix CSRRowWisePick(
    CSRMatrix mat, IdArray rows, int64_t num_picks, bool replace,
    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
  using namespace aten;
  const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
  const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
  const IdxType* data =
      CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;
  const IdxType* rows_data = static_cast<IdxType*>(rows->data);
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
@@ -115,30 +119,36 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
  // have at least num_picks number of nnz when replace is false.
  //
  // If the check holds, remove -1 elements by the remove_if operation, which
  // simply moves valid elements to the head of the arrays and creates a view
  // of the original array. The implementation consumes a little more memory
  // than actually required.
  //
  // Otherwise, directly use the row and col arrays to construct the result
  // COO matrix.
  //
  // [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism
  // is more significant. (minjie)
  // Do not use omp_get_max_threads() since that doesn't work for compiling
  // without OpenMP.
  const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
  std::vector<int64_t> global_prefix(num_threads + 1, 0);

  // TODO(BarclayII) Using OMP parallel directly instead of using
  // runtime::parallel_for does not handle exceptions well (it directly aborts
  // when an exception pops up). It runs faster, though, because there is less
  // scheduling. Need to handle exceptions better.
  IdArray picked_row, picked_col, picked_idx;
#pragma omp parallel num_threads(num_threads)
  {
    const int thread_id = omp_get_thread_num();

    const int64_t start_i =
        thread_id * (num_rows / num_threads) +
        std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
    const int64_t end_i =
        (thread_id + 1) * (num_rows / num_threads) +
        std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
    assert(thread_id + 1 < num_threads || end_i == num_rows);
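    // Example of this partitioning (illustrative, not in the original source):
    // with num_rows = 10 and num_threads = 4, the remainder 2 is spread over
    // the first two threads, giving row ranges [0, 3), [3, 6), [6, 8), [8, 10).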
@@ -149,7 +159,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
    local_prefix[0] = 0;
    for (int64_t i = start_i; i < end_i; ++i) {
      // build prefix-sum
      const int64_t local_i = i - start_i;
      const IdxType rid = rows_data[i];
      IdxType len = num_picks_fn(
          rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
@@ -157,8 +167,8 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
    }
    global_prefix[thread_id + 1] = local_prefix[num_local];

#pragma omp barrier
#pragma omp master
    {
      for (int t = 0; t < num_threads; ++t) {
        global_prefix[t + 1] += global_prefix[t];
@@ -168,7 +178,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
      picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
    }

#pragma omp barrier
    IdxType* picked_rdata = picked_row.Ptr<IdxType>();
    IdxType* picked_cdata = picked_col.Ptr<IdxType>();
    IdxType* picked_idata = picked_idx.Ptr<IdxType>();
@@ -180,14 +190,15 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
      const IdxType off = indptr[rid];
      const IdxType len = indptr[rid + 1] - off;
      if (len == 0) continue;

      const int64_t local_i = i - start_i;
      const int64_t row_offset = thread_offset + local_prefix[local_i];
      const int64_t num_picks =
          thread_offset + local_prefix[local_i + 1] - row_offset;

      pick_fn(
          rid, off, len, num_picks, indices, data, picked_idata + row_offset);
      for (int64_t j = 0; j < num_picks; ++j) {
        const IdxType picked = picked_idata[row_offset + j];
        picked_rdata[row_offset + j] = rid;
@@ -200,25 +211,25 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
  const int64_t new_len = global_prefix.back();

  return COOMatrix(
      mat.num_rows, mat.num_cols,
      picked_row.CreateView({new_len}, picked_row->dtype),
      picked_col.CreateView({new_len}, picked_row->dtype),
      picked_idx.CreateView({new_len}, picked_row->dtype));
}
// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation
// independently.
template <typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypePick(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_picks, bool replace,
    bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,
    const std::vector<NDArray>& prob_or_mask) {
  using namespace aten;
  const IdxType* indptr = mat.indptr.Ptr<IdxType>();
  const IdxType* indices = mat.indices.Ptr<IdxType>();
  const IdxType* eid = CSRHasData(mat) ? mat.data.Ptr<IdxType>() : nullptr;
  const IdxType* rows_data = rows.Ptr<IdxType>();
  const int64_t num_rows = rows->shape[0];
  const auto& ctx = mat.indptr->ctx;
@@ -229,8 +240,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
std::vector<IdArray> picked_idxs(rows->shape[0]); std::vector<IdArray> picked_idxs(rows->shape[0]);
// Check if the number of picks have the same value. // Check if the number of picks have the same value.
// If so, we can potentially speed up if we have a node with total number of neighbors // If so, we can potentially speed up if we have a node with total number of
// less than the given number of picks with replace=False. // neighbors less than the given number of picks with replace=False.
bool same_num_pick = true; bool same_num_pick = true;
int64_t num_pick_value = num_picks[0]; int64_t num_pick_value = num_picks[0];
for (int64_t num_pick : num_picks) { for (int64_t num_pick : num_picks) {
@@ -267,9 +278,10 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
      for (int64_t j = 0; j < len; ++j) {
        const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
        auto it = std::upper_bound(
            eid2etype_offset.begin(), eid2etype_offset.end(),
            homogenized_eid);
        const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
        const IdxType heterogenized_eid =
            homogenized_eid - eid2etype_offset[heterogenized_etype];
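        // Worked example (invented values): with eid2etype_offset = {0, 10,
        // 25}, homogenized eid 12 lies in [10, 25), so upper_bound points at
        // 25, heterogenized_etype = 1 and heterogenized_eid = 12 - 10 = 2.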
        if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
@@ -305,17 +317,21 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
      for (int64_t j = 0; j < len; ++j) {
        const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
        auto it = std::upper_bound(
            eid2etype_offset.begin(), eid2etype_offset.end(),
            homogenized_eid);
        const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
        const IdxType heterogenized_eid =
            homogenized_eid - eid2etype_offset[heterogenized_etype];
        et[j] = heterogenized_etype;
        et_eid[j] = heterogenized_eid;
      }
      if (!rowwise_etype_sorted)  // sort only if edge types are not pre-sorted
        std::sort(
            et_idx.begin(), et_idx.end(),
            [&et](IdxType i1, IdxType i2) { return et[i1] < et[i2]; });
      CHECK_LT(et[et_idx[len - 1]], num_etypes)
          << "etype values exceed the number of fanouts";
      IdxType cur_et = et[et_idx[0]];
      int64_t et_offset = 0;
@@ -333,14 +349,18 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
            // fast path, select all
            for (int64_t k = 0; k < et_len; ++k) {
              const IdxType eid_offset = off + et_idx[et_offset + k];
              const IdxType homogenized_eid =
                  eid ? eid[eid_offset] : eid_offset;
              auto it = std::upper_bound(
                  eid2etype_offset.begin(), eid2etype_offset.end(),
                  homogenized_eid);
              const IdxType heterogenized_etype =
                  it - eid2etype_offset.begin() - 1;
              const IdxType heterogenized_eid =
                  homogenized_eid - eid2etype_offset[heterogenized_etype];
              if (!has_probs ||
                  IsNullArray(prob_or_mask[heterogenized_etype])) {
                // No probability, select all
                rows.push_back(rid);
                cols.push_back(indices[eid_offset]);
@@ -357,32 +377,31 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
              }
            }
          } else {
            IdArray picked_idx =
                Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
            IdxType* picked_idata = picked_idx.Ptr<IdxType>();
            // need to call the random pick function
            pick_fn(
                off, et_offset, cur_et, et_len, et_idx, et_eid, eid,
                picked_idata);
            for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
              const IdxType picked = picked_idata[k];
              if (picked == -1) continue;
              rows.push_back(rid);
              cols.push_back(indices[off + et_idx[et_offset + picked]]);
              if (eid) {
                idx.push_back(eid[off + et_idx[et_offset + picked]]);
              } else {
                idx.push_back(off + et_idx[et_offset + picked]);
              }
            }
          }
          if (j + 1 == len) break;
          // next etype
          cur_et = et[et_idx[j + 1]];
          et_offset = j + 1;
          et_len = 1;
        } else {
          et_len++;
@@ -402,31 +421,32 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
  IdArray picked_row = Concat(picked_rows);
  IdArray picked_col = Concat(picked_cols);
  IdArray picked_idx = Concat(picked_idxs);
  return COOMatrix(
      mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
}

// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts it to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType>
COOMatrix COORowWisePick(
    COOMatrix mat, IdArray rows, int64_t num_picks, bool replace,
    PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
  using namespace aten;
  const auto& csr = COOToCSR(COOSliceRows(mat, rows));
  const IdArray new_rows =
      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
  const auto& picked = CSRRowWisePick<IdxType>(
      csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);
  return COOMatrix(
      mat.num_rows, mat.num_cols,
      IndexSelect(rows, picked.row),  // map the row index to the correct one
      picked.col, picked.data);
}

// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts it to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType, typename DType>
COOMatrix COORowWisePerEtypePick(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
@@ -435,13 +455,15 @@ COOMatrix COORowWisePerEtypePick(
    const std::vector<NDArray>& prob_or_mask) {
  using namespace aten;
  const auto& csr = COOToCSR(COOSliceRows(mat, rows));
  const IdArray new_rows =
      Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
  const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>(
      csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn,
      prob_or_mask);
  return COOMatrix(
      mat.num_rows, mat.num_cols,
      IndexSelect(rows, picked.row),  // map the row index to the correct one
      picked.col, picked.data);
}

}  // namespace impl

...
@@ -4,7 +4,9 @@
 * @brief rowwise sampling
 */
#include <dgl/random.h>

#include <numeric>

#include "./rowwise_pick.h"

namespace dgl {

@@ -13,8 +15,8 @@ namespace impl {
namespace {

// Equivalent to numpy expression: array[idx[off:off + len]]
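// e.g., array = {10, 20, 30, 40} with idx = {3, 0, 2, 1}, off = 1, len = 2
// reads idx[1:3] = {0, 2} and returns {array[0], array[2]} = {10, 30}.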
template <typename IdxType, typename FloatType>
inline FloatArray DoubleSlice(
    FloatArray array, const IdxType* idx_data, IdxType off, IdxType len) {
  const FloatType* array_data = static_cast<FloatType*>(array->data);
  FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);
  FloatType* ret_data = static_cast<FloatType*>(ret->data);
@@ -30,42 +32,44 @@ inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data,
template <typename IdxType, typename DType>
inline NumPicksFn<IdxType> GetSamplingNumPicksFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [prob_or_mask, num_samples, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
    IdxType nnz = 0;
    for (IdxType i = off; i < off + len; ++i) {
      const IdxType eid = data ? data[i] : i;
      if (prob_or_mask_data[eid] > 0) {
        ++nnz;
      }
    }

    if (replace) {
      return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast<IdxType>(max_num_picks), nnz);
    }
  };
  return num_picks_fn;
}

template <typename IdxType, typename DType>
inline PickFn<IdxType> GetSamplingPickFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  PickFn<IdxType> pick_fn = [prob_or_mask, num_samples, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    NDArray prob_or_mask_selected =
        DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);
    RandomEngine::ThreadLocal()->Choice<IdxType, DType>(
        num_picks, prob_or_mask_selected, out_idx, replace);
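    // Choice writes row-local positions in [0, len); rebase them by `off` so
    // they index into this row's segment of the CSR arrays.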
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

@@ -73,108 +77,112 @@ template <typename IdxType, typename FloatType>
inline EtypeRangePickFn<IdxType> GetSamplingRangePickFn(
    const std::vector<int64_t>& num_samples,
    const std::vector<FloatArray>& prob, bool replace) {
  EtypeRangePickFn<IdxType> pick_fn =
      [prob, num_samples, replace](
          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
          const std::vector<IdxType>& et_idx,
          const std::vector<IdxType>& et_eid, const IdxType* eid,
          IdxType* out_idx) {
        const FloatArray& p = prob[cur_et];
        const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr<FloatType>();
        FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);
        FloatType* probs_data = probs.Ptr<FloatType>();
        for (int64_t j = 0; j < et_len; ++j) {
          const IdxType cur_eid = et_eid[et_idx[et_offset + j]];
          probs_data[j] = p_data ? p_data[cur_eid] : static_cast<FloatType>(1.);
        }

        RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
            num_samples[cur_et], probs, out_idx, replace);
      };
  return pick_fn;
}

template <typename IdxType>
inline NumPicksFn<IdxType> GetSamplingUniformNumPicksFn(
    int64_t num_samples, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    if (replace) {
      return static_cast<IdxType>(len == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast<IdxType>(max_num_picks), len);
    }
  };
  return num_picks_fn;
}

template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
    int64_t num_samples, bool replace) {
  PickFn<IdxType> pick_fn = [num_samples, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
        num_picks, len, out_idx, replace);
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

template <typename IdxType>
inline EtypeRangePickFn<IdxType> GetSamplingUniformRangePickFn(
    const std::vector<int64_t>& num_samples, bool replace) {
  EtypeRangePickFn<IdxType> pick_fn =
      [num_samples, replace](
          IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
          const std::vector<IdxType>& et_idx,
          const std::vector<IdxType>& et_eid, const IdxType* data,
          IdxType* out_idx) {
        RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
            num_samples[cur_et], et_len, out_idx, replace);
      };
  return pick_fn;
}

template <typename IdxType, typename FloatType>
inline NumPicksFn<IdxType> GetSamplingBiasedNumPicksFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  NumPicksFn<IdxType> num_picks_fn = [num_samples, split, bias, replace](
                                         IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const int64_t num_tags = split->shape[1] - 1;
    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
    const FloatType* bias_data = bias.Ptr<FloatType>();
    IdxType nnz = 0;
    for (int64_t j = 0; j < num_tags; ++j) {
      if (bias_data[j] > 0) {
        nnz += tag_offset[j + 1] - tag_offset[j];
      }
    }
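    // e.g., with two tags, a tag_offset row of {0, 3, 5} and bias = {0.5, 0.},
    // only the first segment is samplable, so nnz = 3.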
    if (replace) {
      return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast<IdxType>(max_num_picks), nnz);
    }
  };
  return num_picks_fn;
}

template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  PickFn<IdxType> pick_fn = [num_samples, split, bias, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
    RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
        num_picks, tag_offset, bias, out_idx, replace);
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
@@ -187,15 +195,16 @@ inline PickFn<IdxType> GetSamplingBiasedPickFn(

/////////////////////////////// CSR ///////////////////////////////

template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  auto pick_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}

@@ -219,49 +228,52 @@ template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors does not match the number of edge "
         "types.";
  for (auto& p : prob_or_mask) CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return CSRRowWisePerEtypePick<IdxType, DType>(
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, prob_or_mask);
}

template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);

template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
  return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
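
A usage sketch (not part of this diff): the call below is illustrative only,
assuming the VecToIdArray helper and the CSRMatrix constructor with default
null data are visible from the caller.

// Sketch: uniformly sample up to 2 neighbors per requested row, without
// replacement, from a small 3x4 CSR matrix.
using namespace dgl;
using namespace dgl::aten;
CSRMatrix csr(
    3, 4,                                     // num_rows, num_cols
    VecToIdArray<int64_t>({0, 2, 3, 5}),      // indptr
    VecToIdArray<int64_t>({0, 1, 2, 1, 3}));  // indices
IdArray sample_rows = VecToIdArray<int64_t>({0, 2});
COOMatrix sampled = impl::CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
    csr, sample_rows, /*num_samples=*/2, /*replace=*/false);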
@@ -274,28 +286,25 @@ template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples, bool replace,
    bool rowwise_etype_sorted) {
  auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
  return CSRRowWisePerEtypePick<IdxType, float>(
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, {});
}

template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
    CSRMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool, bool);

template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingBiasedNumPicksFn<IdxType, FloatType>(
@@ -306,30 +315,30 @@ COOMatrix CSRRowWiseSamplingBiased(
}

template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

/////////////////////////////// COO ///////////////////////////////

template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWiseSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn =
      GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  auto pick_fn =
      GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}

@@ -353,48 +362,50 @@ template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace) {
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors does not match the number of edge "
         "types.";
  for (auto& p : prob_or_mask) CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
      num_samples, prob_or_mask, replace);
  return COORowWisePerEtypePick<IdxType, DType>(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, prob_or_mask);
}
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, const std::vector<NDArray>&, bool);

template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(
    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
  auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
  return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}

@@ -414,9 +425,11 @@ COOMatrix COORowWisePerEtypeSamplingUniform(
}

template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
    COOMatrix, IdArray, const std::vector<int64_t>&,
    const std::vector<int64_t>&, bool);

}  // namespace impl
}  // namespace aten

...

@@ -3,8 +3,9 @@
 * @file array/cpu/rowwise_topk.cc
 * @brief rowwise topk
 */
#include <algorithm>
#include <numeric>

#include "./rowwise_pick.h"

namespace dgl {

@@ -14,52 +15,52 @@ namespace {
template <typename IdxType>
inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
  NumPicksFn<IdxType> num_picks_fn = [k](IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (k == -1) ? len : k;
    return std::min(static_cast<IdxType>(max_num_picks), len);
  };
  return num_picks_fn;
}
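
// The pick function below orders each row's candidate positions by weight,
// indirecting through `data` when edge IDs are shuffled, and keeps the first
// num_picks entries, i.e. the top-k under the requested ordering.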
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) {
  const DType* wdata = static_cast<DType*>(weight->data);
  PickFn<IdxType> pick_fn = [ascending, wdata](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    std::function<bool(IdxType, IdxType)> compare_fn;
    if (ascending) {
      if (data) {
        compare_fn = [wdata, data](IdxType i, IdxType j) {
          return wdata[data[i]] < wdata[data[j]];
        };
      } else {
        compare_fn = [wdata](IdxType i, IdxType j) {
          return wdata[i] < wdata[j];
        };
      }
    } else {
      if (data) {
        compare_fn = [wdata, data](IdxType i, IdxType j) {
          return wdata[data[i]] > wdata[data[j]];
        };
      } else {
        compare_fn = [wdata](IdxType i, IdxType j) {
          return wdata[i] > wdata[j];
        };
      }
    }

    std::vector<IdxType> idx(len);
    std::iota(idx.begin(), idx.end(), off);
    std::sort(idx.begin(), idx.end(), compare_fn);
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] = idx[j];
    }
  };
  return pick_fn;
}

...

@@ -4,73 +4,65 @@
 * @brief SDDMM C APIs and definitions.
 */
#include "./sddmm.h"

#include <dgl/array.h>

namespace dgl {
namespace aten {
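
// By DGL convention, a target value of 0 selects source-node data (u),
// 1 selects edge data (e), and 2 selects destination-node data (v).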
#define SWITCH_RHS(rhs_target, RhsTarget, ...)             \
  do {                                                     \
    if ((rhs_target) == 0) {                               \
      constexpr int RhsTarget = 0;                         \
      { __VA_ARGS__ }                                      \
    } else if ((rhs_target) == 1) {                        \
      constexpr int RhsTarget = 1;                         \
      { __VA_ARGS__ }                                      \
    } else if ((rhs_target) == 2) {                        \
      constexpr int RhsTarget = 2;                         \
      { __VA_ARGS__ }                                      \
    } else {                                               \
      LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
    }                                                      \
  } while (0)

#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
  do {                                                                   \
    if ((lhs_target) == 0) {                                             \
      constexpr int LhsTarget = 0;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 1) {                                      \
      constexpr int LhsTarget = 1;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else if ((lhs_target) == 2) {                                      \
      constexpr int LhsTarget = 2;                                       \
      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);                    \
    } else {                                                             \
      LOG(INFO) << "Invalid lhs target: " << (lhs_target);               \
    }                                                                    \
  } while (0)

/** @brief Generalized SDDMM on Csr format. */
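// For every nonzero (i, j) with edge ID e of the sparse matrix, SDDMM writes
// out[e] = Op(lhs data at lhs_target, rhs data at rhs_target), e.g. scoring
// each edge by a dot product of its endpoint features.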
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, csr, lhs, rhs, out);
    });
  });
}

/** @brief Generalized SDDMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid) {
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call SDDMM for each relation type */
@@ -79,7 +71,8 @@ void SDDMMCsrHetero(const std::string& op,
        NDArray lhs = vec_lhs[lhs_nid[etype]];
        NDArray rhs = vec_rhs[rhs_nid[etype]];
        NDArray out = vec_out[etype];
        cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, csr, lhs, rhs, out);
      }
    });
  });
@@ -87,79 +80,63 @@ void SDDMMCsrHetero(const std::string& op,
template void SDDMMCsr<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
/** @brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
          bcast, coo, lhs, rhs, out);
    });
  });
}

/** @brief Generalized SDDMM on Coo format with Heterograph support. */
template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
    const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
    int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
    const std::vector<dgl_type_t>& rhs_nid) {
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      /* Call SDDMM for each relation type */
@@ -168,7 +145,8 @@ void SDDMMCooHetero(const std::string& op,
        NDArray lhs = vec_lhs[lhs_nid[etype]];
        NDArray rhs = vec_rhs[rhs_nid[etype]];
        NDArray out = vec_out[etype];
        cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
            bcast, coo, lhs, rhs, out);
      }
    });
  });
@@ -176,48 +154,40 @@ void SDDMMCooHetero(
template void SDDMMCoo<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
    const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target, const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);

}  // namespace aten
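
// A minimal, self-contained sketch of the semantics dispatched by the
// SDDMMCoo/SDDMMCooHetero functions above, assuming op == "mul" and operands
// read from source/destination nodes (lhs_target/rhs_target select other
// placements). SimpleCoo and the flat float vectors are illustrative
// stand-ins, not DGL's COOMatrix/NDArray API; outputs are assumed
// preallocated with one slot per edge.
#include <cstdint>
#include <vector>

struct SimpleCoo {
  std::vector<int64_t> row;  // source node of each edge
  std::vector<int64_t> col;  // destination node of each edge
};

// One relation: out[e] = lhs[row[e]] * rhs[col[e]] for every edge e.
inline void SddmmCooMulSketch(
    const SimpleCoo& coo, const std::vector<float>& lhs,
    const std::vector<float>& rhs, std::vector<float>* out) {
  for (size_t e = 0; e < coo.row.size(); ++e)
    (*out)[e] = lhs[coo.row[e]] * rhs[coo.col[e]];
}

// Hetero dispatch: lhs_nid/rhs_nid map each relation to the node-type
// feature tensor it reads, mirroring vec_lhs[lhs_nid[etype]] above.
inline void SddmmCooHeteroMulSketch(
    const std::vector<SimpleCoo>& rel_graphs,
    const std::vector<std::vector<float>>& lhs_by_ntype,
    const std::vector<std::vector<float>>& rhs_by_ntype,
    const std::vector<int>& lhs_nid, const std::vector<int>& rhs_nid,
    std::vector<std::vector<float>>* out_by_etype) {
  for (size_t etype = 0; etype < rel_graphs.size(); ++etype)
    SddmmCooMulSketch(
        rel_graphs[etype], lhs_by_ntype[lhs_nid[etype]],
        rhs_by_ntype[rhs_nid[etype]], &(*out_by_etype)[etype]);
}
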
...
@@ -4,8 +4,11 @@
 * @brief Segment reduce C APIs and definitions.
 */
#include "./segment_reduce.h"

#include <dgl/array.h>

#include <string>

#include "./spmm_binary_ops.h"

namespace dgl {
@@ -14,10 +17,7 @@ namespace aten {
/** @brief Segment Reduce operator. */
template <int XPU, typename IdType, typename DType>
void SegmentReduce(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg) {
  if (op == "sum") {
    cpu::SegmentSum<IdType, DType>(feat, offsets, out);
@@ -36,73 +36,47 @@ void SegmentReduce(
/** @brief Scatter Add.*/
template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  cpu::ScatterAdd<IdType, DType>(feat, idx, out);
}
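
// A rough 1-D sketch of the two primitives dispatched above; the real
// cpu::SegmentSum / cpu::ScatterAdd operate on NDArrays with arbitrary
// feature dimensions, this only shows the indexing contract. offsets is
// assumed non-empty (it always begins with 0).
#include <cstdint>
#include <vector>

// SegmentSum: offsets[i]..offsets[i+1] delimits segment i of feat, so
// out has offsets.size() - 1 entries.
inline std::vector<float> SegmentSumSketch(
    const std::vector<float>& feat, const std::vector<int64_t>& offsets) {
  std::vector<float> out(offsets.size() - 1, 0.f);
  for (size_t i = 0; i + 1 < offsets.size(); ++i)
    for (int64_t j = offsets[i]; j < offsets[i + 1]; ++j) out[i] += feat[j];
  return out;
}

// ScatterAdd: out[idx[i]] += feat[i]; duplicate indices accumulate.
inline void ScatterAddSketch(
    const std::vector<float>& feat, const std::vector<int64_t>& idx,
    std::vector<float>* out) {
  for (size_t i = 0; i < feat.size(); ++i) (*out)[idx[i]] += feat[i];
}
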
/** @brief Update gradients for reduce operator max/min on heterogeneous
 * graph.*/
template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    const HeteroGraphPtr& g, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
  cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
}

/** @brief Backward function of segment cmp.*/
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}
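
// Sketch of the backward rule behind cpu::BackwardSegmentCmp, assuming
// arg[i] records which input position won the forward max/min comparison
// for output i (negative when the segment was empty): the upstream
// gradient is routed only to the winning positions.
#include <cstdint>
#include <vector>

inline void BackwardSegmentCmpSketch(
    const std::vector<float>& grad_out,  // upstream gradient ("feat")
    const std::vector<int64_t>& arg,     // forward argmax/argmin positions
    std::vector<float>* grad_in) {       // gradient w.r.t. the input ("out")
  for (size_t i = 0; i < grad_out.size(); ++i)
    if (arg[i] >= 0) (*grad_in)[arg[i]] += grad_out[i];
}
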
template void SegmentReduce<kDGLCPU, int32_t, float>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, float>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, double>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, double>(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg);
template void ScatterAdd<kDGLCPU, int32_t, float>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, float>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, double>(
    NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, double>(
    NDArray feat, NDArray arg, NDArray out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
    const HeteroGraphPtr& g, const std::string& op,
@@ -122,21 +96,13 @@ template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, double>(
    NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, double>(
    NDArray feat, NDArray arg, NDArray out);

}  // namespace aten
}  // namespace dgl
@@ -4,6 +4,7 @@
 * @brief SPMM C APIs and definitions.
 */
#include "./spmm.h"

#include <dgl/array.h>

namespace dgl {
@@ -11,13 +12,10 @@ namespace aten {
/** @brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
@@ -25,13 +23,15 @@ void SpMMCsr(
    });
  } else if (reduce == "max" || reduce == "min") {
    SWITCH_OP(op, Op, {
      DType* out_off = out.Ptr<DType>();
      if (reduce == "max") {
        std::fill(
            out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      } else {
        std::fill(
            out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
        cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
      }
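
// Host-side sketch of the max branch above for op == "copy_lhs": the
// output is first filled with the reduction identity (the analogue of
// cpu::op::Max<DType>::zero), then each CSR row keeps the largest incident
// value and remembers where it came from, as the out_aux arrays do.
// SimpleCsr is an illustrative stand-in for dgl::aten::CSRMatrix.
#include <cstdint>
#include <limits>
#include <vector>

struct SimpleCsr {
  std::vector<int64_t> indptr;   // row offsets, size num_rows + 1
  std::vector<int64_t> indices;  // neighbor ids, one per edge
};

inline void SpmmCsrMaxCopyLhsSketch(
    const SimpleCsr& csr, const std::vector<float>& ufeat,
    std::vector<float>* out, std::vector<int64_t>* argu) {
  out->assign(csr.indptr.size() - 1, -std::numeric_limits<float>::infinity());
  argu->assign(csr.indptr.size() - 1, -1);  // -1 marks "no incident edge"
  for (size_t r = 0; r + 1 < csr.indptr.size(); ++r)
    for (int64_t e = csr.indptr[r]; e < csr.indptr[r + 1]; ++e)
      if (ufeat[csr.indices[e]] > (*out)[r]) {
        (*out)[r] = ufeat[csr.indices[e]];  // running row maximum
        (*argu)[r] = csr.indices[e];        // argmax, cf. out_aux[0]
      }
}
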
@@ -43,15 +43,14 @@ void SpMMCsr(
/** @brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& vec_csr,
    const std::vector<NDArray>& vec_ufeat,
    const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids) {
  const int64_t dim = bcast.out_len;
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
@@ -60,8 +59,10 @@ void SpMMCsrHetero(
      const dgl_type_t src_id = ufeat_node_tids[etype];
      const dgl_type_t dst_id = out_node_tids[etype];
      CSRMatrix csr = vec_csr[etype];
      NDArray ufeat =
          (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
      NDArray efeat =
          (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
      NDArray out = (*vec_out)[dst_id];
      cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
    }
@@ -71,21 +72,27 @@ void SpMMCsrHetero(
    std::vector<bool> updated((*vec_out).size(), false);
    // TODO(Israt): use vector updated to fill(out...) too
    for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
      DType* out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
      if (reduce == "max")
        std::fill(
            out_off, out_off + vec_csr[etype].num_rows * dim,
            cpu::op::Max<DType>::zero);
      else
        std::fill(
            out_off, out_off + vec_csr[etype].num_rows * dim,
            cpu::op::Min<DType>::zero);
      const dgl_type_t dst_id = out_node_tids[etype];
      if (!updated[dst_id]) {
        updated[dst_id] = true;
        if (Op::use_lhs) {
          IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
          std::fill(
              argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
        }
        if (Op::use_rhs) {
          IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
          std::fill(
              arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
        }
      }
    }
@@ -94,17 +101,21 @@ void SpMMCsrHetero(
      const dgl_type_t src_id = ufeat_node_tids[etype];
      const dgl_type_t dst_id = out_node_tids[etype];
      CSRMatrix csr = vec_csr[etype];
      NDArray ufeat =
          (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
      NDArray efeat =
          (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
      NDArray out = (*vec_out)[dst_id];
      if (reduce == "max") {
        cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
            bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
            (*out_aux)[3][dst_id], src_id, etype);
      } else {
        cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
            bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
            (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
            (*out_aux)[3][dst_id], src_id, etype);
      }
    }
  });
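
// The `updated` bookkeeping above exists because several relation types
// can reduce into the same destination-type output, so each arg buffer
// must be reset to -1 exactly once before any relation writes it (the
// TODO notes the identity fill of the outputs could share the guard).
// A simplified 1-D sketch of that initialize-once pass, with float -inf
// standing in for the reduction identity:
#include <cstdint>
#include <limits>
#include <vector>

inline void InitOncePerDstTypeSketch(
    const std::vector<int>& out_node_tids,     // dst node type per etype
    std::vector<std::vector<float>>* vec_out,  // one output per node type
    std::vector<std::vector<int64_t>>* arg) {  // one arg buffer per type
  std::vector<bool> updated(vec_out->size(), false);
  for (size_t etype = 0; etype < out_node_tids.size(); ++etype) {
    const int dst_id = out_node_tids[etype];
    if (updated[dst_id]) continue;  // another relation already reset it
    updated[dst_id] = true;
    (*vec_out)[dst_id].assign(
        (*vec_out)[dst_id].size(), -std::numeric_limits<float>::infinity());
    (*arg)[dst_id].assign((*arg)[dst_id].size(), -1);
  }
}
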
@@ -114,120 +125,105 @@ void SpMMCsrHetero(
}

template void SpMMCsr<kDGLCPU, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
    const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux,
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
/** @brief Edge_softmax_csr forward op on Csr format. */
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out) {
  SWITCH_OP(op, Op, {
    cpu::Edge_softmax_csr_forward<IdType, DType, Op>(
        bcast, csr, ufeat, efeat, out);
  });
}

/** @brief Edge_softmax_csr backward op on Csr format. */
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out) {
  SWITCH_OP(op, Op, {
    cpu::Edge_softmax_csr_backward<IdType, DType, Op>(
        bcast, csr, out, sds, back_out);
  });
}
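
// Sketch of what the forward op computes per destination node, assuming
// each CSR row groups the incoming edges of one node and efeat carries a
// scalar logit per edge (the real kernels broadcast via bcast, and whether
// they subtract the row maximum is an implementation detail; it is the
// usual stability trick and is shown here). The backward op then applies
// the standard softmax Jacobian using out and sds.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

inline void EdgeSoftmaxSketch(
    const std::vector<int64_t>& indptr,  // CSR row offsets
    const std::vector<float>& efeat,     // one logit per edge
    std::vector<float>* out) {           // softmax within each row
  for (size_t r = 0; r + 1 < indptr.size(); ++r) {
    float m = -INFINITY, z = 0.f;
    for (int64_t e = indptr[r]; e < indptr[r + 1]; ++e)
      m = std::max(m, efeat[e]);  // row max, for numerical stability
    for (int64_t e = indptr[r]; e < indptr[r + 1]; ++e)
      z += std::exp(efeat[e] - m);
    for (int64_t e = indptr[r]; e < indptr[r + 1]; ++e)
      (*out)[e] = std::exp(efeat[e] - m) / z;
  }
}
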
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
    const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
/** @brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, typename DType>
void SpMMCoo(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux) {
  if (reduce == "sum") {
    SWITCH_OP(op, Op, {
      cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
@@ -247,21 +243,21 @@ void SpMMCoo(
}

template void SpMMCoo<kDGLCPU, int32_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, float>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, double>(
    const std::string& op, const std::string& reduce, const BcastOff& bcast,
    const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
    std::vector<NDArray> out_aux);

}  // namespace aten
}  // namespace dgl
@@ -7,6 +7,7 @@
#define DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_
#include <dgl/array.h>
#include <dgl/bcast.h>

#include <limits>
namespace dgl {
namespace aten {
...
@@ -63,8 +63,10 @@ template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#if BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int32_t>(
    NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(
    NDArray, IdArray);
#endif  // BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
@@ -76,8 +78,8 @@ DType IndexSelect(NDArray array, int64_t index) {
  auto device = runtime::DeviceAPI::Get(array->ctx);
  DType ret = static_cast<DType>(0.0f);
  device->CopyDataFromTo(
      static_cast<DType*>(array->data) + index, 0, &ret, 0, sizeof(DType),
      array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
  return ret;
}
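
// The scalar IndexSelect above is a single-element device-to-host read:
// DeviceAPI::CopyDataFromTo moves sizeof(DType) bytes at the element offset
// back into a stack variable. A rough raw-CUDA-runtime equivalent
// (ReadOneElement is an illustrative name, not a DGL function):
#include <cuda_runtime.h>

#include <cstdint>

template <typename DType>
DType ReadOneElement(const DType* device_array, int64_t index) {
  DType ret{};
  // Synchronous copy of exactly one element back to the host.
  cudaMemcpy(&ret, device_array + index, sizeof(DType), cudaMemcpyDeviceToHost);
  return ret;
}
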
@@ -87,7 +89,8 @@ template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
#if BF16_ENABLED
template __nv_bfloat16 IndexSelect<kDGLCUDA, __nv_bfloat16>(
    NDArray array, int64_t index);
#endif  // BF16_ENABLED
template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
...
@@ -4,10 +4,11 @@
 * @brief Array operator GPU implementation
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/cuda/cuda_hashtable.cuh"
#include "../arith.h"
#include "./utils.h"

namespace dgl {
using runtime::NDArray;
@@ -38,35 +39,56 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int nt = cuda::FindNumThreads(len);
  int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      (_BinaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs_data, rhs_data,
      ret_data, len);
  return ret;
}
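
// The launch arithmetic above is the standard ceil-division pattern:
// cuda::FindNumThreads picks nt threads per block, and
// nb = (len + nt - 1) / nt blocks guarantees nb * nt >= len. A generic
// sketch of a matching grid-stride kernel follows; the names are
// illustrative, and Op is assumed to expose a static Call, the convention
// DGL's arith functors follow.
#include <cstdint>

template <typename IdType, typename Op>
__global__ void BinaryElewiseSketch(
    const IdType* lhs, const IdType* rhs, IdType* out, int64_t len) {
  int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t stride = gridDim.x * blockDim.x;
  while (tx < len) {  // grid-stride loop also covers a clamped grid size
    out[tx] = Op::Call(lhs[tx], rhs[tx]);
    tx += stride;
  }
}
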
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(
    IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(
    IdArray lhs, IdArray rhs);
template <typename IdType, typename Op>
__global__ void _BinaryElewiseKernel(
@@ -88,36 +110,56 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int nt = cuda::FindNumThreads(len);
  int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      (_BinaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs_data, rhs,
      ret_data, len);
  return ret;
}
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(
    IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(
    IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(
    IdArray lhs, int64_t rhs);
template <typename IdType, typename Op>
__global__ void _BinaryElewiseKernel(
@@ -139,34 +181,56 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int nt = cuda::FindNumThreads(len);
  int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      (_BinaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs, rhs_data,
      ret_data, len);
  return ret;
}
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(
    int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(
    int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(
    int64_t lhs, IdArray rhs);
template <typename IdType, typename Op>
__global__ void _UnaryElewiseKernel(
@@ -188,9 +252,9 @@ IdArray UnaryElewise(IdArray lhs) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int nt = cuda::FindNumThreads(len);
  int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      (_UnaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs_data, ret_data,
      len);
  return ret;
}
@@ -200,8 +264,7 @@ template IdArray UnaryElewise<kDGLCUDA, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full /////////////////////////////

template <typename DType>
__global__ void _FullKernel(DType* out, int64_t length, DType val) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
@@ -217,20 +280,25 @@ NDArray Full(DType val, int64_t length, DGLContext ctx) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int nt = cuda::FindNumThreads(length);
  int nb = (length + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      (_FullKernel<DType>), nb, nt, 0, stream, ret_data, length, val);
  return ret;
}

template IdArray Full<kDGLCUDA, int32_t>(
    int32_t val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, int64_t>(
    int64_t val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, __half>(
    __half val, int64_t length, DGLContext ctx);
#if BF16_ENABLED
template IdArray Full<kDGLCUDA, __nv_bfloat16>(
    __nv_bfloat16 val, int64_t length, DGLContext ctx);
#endif  // BF16_ENABLED
template IdArray Full<kDGLCUDA, float>(
    float val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, double>(
    double val, int64_t length, DGLContext ctx);
///////////////////////////// Range /////////////////////////////
@@ -249,15 +317,13 @@ IdArray Range(IdType low, IdType high, DGLContext ctx) {
  CHECK(high >= low) << "high must be bigger than low";
  const IdType length = high - low;
  IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
  if (length == 0) return ret;
  IdType* ret_data = static_cast<IdType*>(ret->data);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int nt = cuda::FindNumThreads(length);
  int nb = (length + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      (_RangeKernel<IdType>), nb, nt, 0, stream, ret_data, low, length);
  return ret;
}
@@ -269,7 +335,6 @@ template IdArray Range<kDGLCUDA, int64_t>(int64_t, int64_t, DGLContext);
template <typename IdType>
__global__ void _RelabelKernel(
    IdType* out, int64_t length, DeviceOrderedHashTable<IdType> table) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
@@ -295,30 +360,20 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
  // build node maps and get the induced nodes
  OrderedHashTable<IdType> node_map(total_length, ctx, stream);
  int64_t num_induced = 0;
  int64_t* num_induced_device =
      static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
  IdArray induced_nodes = NewIdArray(total_length, ctx, sizeof(IdType) * 8);
  CUDA_CALL(cudaMemsetAsync(
      num_induced_device, 0, sizeof(*num_induced_device), stream));

  node_map.FillWithDuplicates(
      all_nodes.Ptr<IdType>(), all_nodes->shape[0], induced_nodes.Ptr<IdType>(),
      num_induced_device, stream);

  // copy using the internal current stream
  device->CopyDataFromTo(
      num_induced_device, 0, &num_induced, 0, sizeof(num_induced), ctx,
      DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});

  device->StreamSync(ctx, stream);
  device->FreeWorkspace(ctx, num_induced_device);
@@ -331,16 +386,18 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
  for (IdArray arr : arrays) {
    const int64_t length = arr->shape[0];
    int nb = (length + nt - 1) / nt;
    CUDA_KERNEL_CALL(
        (_RelabelKernel<IdType>), nb, nt, 0, stream, arr.Ptr<IdType>(), length,
        node_map.DeviceHandle());
  }

  return induced_nodes;
}
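
// CPU analogue of Relabel_ above: deduplicate all ids into a compact new
// id space (the induced nodes) and rewrite every input array in place.
// The GPU path does the same with OrderedHashTable and _RelabelKernel;
// the exact ordering produced by the hash table is an implementation
// detail, first-seen order is assumed here for clarity.
#include <cstdint>
#include <unordered_map>
#include <vector>

inline std::vector<int64_t> RelabelSketch(
    std::vector<std::vector<int64_t>>* arrays) {
  std::unordered_map<int64_t, int64_t> remap;
  std::vector<int64_t> induced;  // old id of each new id
  for (auto& arr : *arrays)
    for (auto& v : arr) {
      auto it = remap.find(v);
      if (it == remap.end()) {
        it = remap.emplace(v, static_cast<int64_t>(induced.size())).first;
        induced.push_back(v);
      }
      v = it->second;  // in-place rewrite, like _RelabelKernel
    }
  return induced;
}
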
template IdArray Relabel_<kDGLCUDA, int32_t>(
    const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCUDA, int64_t>(
    const std::vector<IdArray>& arrays);

///////////////////////////// AsNumBits /////////////////////////////
@@ -363,18 +420,19 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
  int nt = cuda::FindNumThreads(length);
  int nb = (length + nt - 1) / nt;
  if (bits == 32) {
    CUDA_KERNEL_CALL(
        (_CastKernel<IdType, int32_t>), nb, nt, 0, stream,
        static_cast<IdType*>(arr->data), static_cast<int32_t*>(ret->data),
        length);
  } else {
    CUDA_KERNEL_CALL(
        (_CastKernel<IdType, int64_t>), nb, nt, 0, stream,
        static_cast<IdType*>(arr->data), static_cast<int64_t*>(ret->data),
        length);
  }
  return ret;
}
template IdArray AsNumBits<kDGLCUDA, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCUDA, int64_t>(IdArray arr, uint8_t bits);
...
@@ -4,6 +4,7 @@
 * @brief Array scatter GPU implementation
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
@@ -13,8 +14,8 @@ namespace aten {
namespace impl {

template <typename DType, typename IdType>
__global__ void _ScatterKernel(
    const IdType* index, const DType* value, int64_t length, DType* out) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
@@ -33,15 +34,15 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int nt = cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream, idx, val, len, outd);
}
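
// The kernel above performs a plain scatter: one thread per input element
// writes value[i] to out[index[i]]. A serial sketch of that contract;
// note that with duplicate indices the device writes race, so the winning
// value for such a slot is unspecified.
#include <cstdint>
#include <vector>

template <typename DType, typename IdType>
void ScatterSketch(
    const std::vector<IdType>& index, const std::vector<DType>& value,
    std::vector<DType>* out) {
  for (size_t i = 0; i < index.size(); ++i) (*out)[index[i]] = value[i];
}
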
template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
#if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int32_t>(
    IdArray, NDArray, NDArray);
#endif  // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
@@ -49,7 +50,8 @@ template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
#if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int64_t>(
    IdArray, NDArray, NDArray);
#endif  // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
...