Unverified Commit 8ac27dad authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4824)



* [Misc] clang-format auto fix.

* blabla

* ablabla

* blabla
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bcd37684
...@@ -10,10 +10,11 @@ ...@@ -10,10 +10,11 @@
#define DGL_ATEN_ARRAY_OPS_H_ #define DGL_ATEN_ARRAY_OPS_H_
#include <algorithm> #include <algorithm>
#include <string>
#include <tuple>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <tuple>
#include <string>
#include "./types.h" #include "./types.h"
namespace dgl { namespace dgl {
...@@ -24,7 +25,8 @@ namespace aten { ...@@ -24,7 +25,8 @@ namespace aten {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
/** @return A special array to represent null. */ /** @return A special array to represent null. */
inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1}, inline NDArray NullArray(
const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
const DGLContext& ctx = DGLContext{kDGLCPU, 0}) { const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
return NDArray::Empty({0}, dtype, ctx); return NDArray::Empty({0}, dtype, ctx);
} }
...@@ -32,9 +34,7 @@ inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1}, ...@@ -32,9 +34,7 @@ inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
/** /**
* @return Whether the input array is a null array. * @return Whether the input array is a null array.
*/ */
inline bool IsNullArray(NDArray array) { inline bool IsNullArray(NDArray array) { return array->shape[0] == 0; }
return array->shape[0] == 0;
}
/** /**
* @brief Create a new id array with given length * @brief Create a new id array with given length
...@@ -43,8 +43,8 @@ inline bool IsNullArray(NDArray array) { ...@@ -43,8 +43,8 @@ inline bool IsNullArray(NDArray array) {
* @param nbits The number of integer bits * @param nbits The number of integer bits
* @return id array * @return id array
*/ */
IdArray NewIdArray(int64_t length, IdArray NewIdArray(
DGLContext ctx = DGLContext{kDGLCPU, 0}, int64_t length, DGLContext ctx = DGLContext{kDGLCPU, 0},
uint8_t nbits = 64); uint8_t nbits = 64);
/** /**
...@@ -55,8 +55,8 @@ IdArray NewIdArray(int64_t length, ...@@ -55,8 +55,8 @@ IdArray NewIdArray(int64_t length,
* @return the id array * @return the id array
*/ */
template <typename T> template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec, IdArray VecToIdArray(
uint8_t nbits = 64, const std::vector<T>& vec, uint8_t nbits = 64,
DGLContext ctx = DGLContext{kDGLCPU, 0}); DGLContext ctx = DGLContext{kDGLCPU, 0});
/** /**
...@@ -148,7 +148,7 @@ IdArray NonZero(BoolArray bool_arr); ...@@ -148,7 +148,7 @@ IdArray NonZero(BoolArray bool_arr);
* @brief Return the data under the index. In numpy notation, A[I] * @brief Return the data under the index. In numpy notation, A[I]
* @tparam ValueType The type of return value. * @tparam ValueType The type of return value.
*/ */
template<typename ValueType> template <typename ValueType>
ValueType IndexSelect(NDArray array, int64_t index); ValueType IndexSelect(NDArray array, int64_t index);
/** /**
...@@ -187,10 +187,11 @@ NDArray Scatter(NDArray array, IdArray indices); ...@@ -187,10 +187,11 @@ NDArray Scatter(NDArray array, IdArray indices);
void Scatter_(IdArray index, NDArray value, NDArray out); void Scatter_(IdArray index, NDArray value, NDArray out);
/** /**
* @brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats) * @brief Repeat each element a number of times. Equivalent to np.repeat(array,
* repeats)
* @param array A 1D vector * @param array A 1D vector
* @param repeats A 1D integer vector for number of times to repeat for each element in * @param repeats A 1D integer vector for number of times to repeat for each
* \c array. Must have the same shape as \c array. * element in \c array. Must have the same shape as \c array.
*/ */
NDArray Repeat(NDArray array, IdArray repeats); NDArray Repeat(NDArray array, IdArray repeats);
...@@ -253,12 +254,13 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) { ...@@ -253,12 +254,13 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
* *
* The packed tensor would be [1, 2, 3, 4, 5]. * The packed tensor would be [1, 2, 3, 4, 5].
* *
* The length tensor would be [2, 3], i.e. the length of each sequence before padding. * The length tensor would be [2, 3], i.e. the length of each sequence before
* padding.
* *
* The offset tensor would be [0, 2], i.e. the offset to the packed tensor for each * The offset tensor would be [0, 2], i.e. the offset to the packed tensor for
* sequence (before padding) * each sequence (before padding)
*/ */
template<typename ValueType> template <typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value); std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
/** /**
...@@ -295,8 +297,9 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths); ...@@ -295,8 +297,9 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
* @brief Return the cumulative summation (or inclusive sum) of the input array. * @brief Return the cumulative summation (or inclusive sum) of the input array.
* *
* The first element out[0] is equal to the first element of the input array * The first element out[0] is equal to the first element of the input array
* array[0]. The rest elements are defined recursively, out[i] = out[i-1] + array[i]. * array[0]. The rest elements are defined recursively, out[i] = out[i-1] +
* Hence, the result array length is the same as the input array length. * array[i]. Hence, the result array length is the same as the input array
* length.
* *
* If prepend_zero is true, then the first element is zero and the result array * If prepend_zero is true, then the first element is zero and the result array
* length is the input array length plus one. This is useful for creating * length is the input array length plus one. This is useful for creating
...@@ -320,17 +323,18 @@ IdArray NonZero(NDArray array); ...@@ -320,17 +323,18 @@ IdArray NonZero(NDArray array);
/** /**
* @brief Sort the ID vector in ascending order. * @brief Sort the ID vector in ascending order.
* *
* It performs both sort and arg_sort (returning the sorted index). The sorted index * It performs both sort and arg_sort (returning the sorted index). The sorted
* is always in int64. * index is always in int64.
* *
* @param array Input array. * @param array Input array.
* @param num_bits The number of bits used in key comparison. For example, if the data type * @param num_bits The number of bits used in key comparison. For example, if
* of the input array is int32_t and `num_bits = 8`, it only uses bits in index * the data type of the input array is int32_t and `num_bits = 8`, it only uses
* range [0, 8) for sorting. Setting it to a small value could * bits in index range [0, 8) for sorting. Setting it to a small value could
* speed up the sorting if the underlying sorting algorithm is radix sort (e.g., on GPU). * speed up the sorting if the underlying sorting algorithm is
* Setting it to zero (default value) means using all the bits for comparison. * radix sort (e.g., on GPU). Setting it to zero (default value) means using all
* On CPU, it currently has no effect. * the bits for comparison. On CPU, it currently has no effect.
* @return A pair of arrays: sorted values and sorted index to the original position. * @return A pair of arrays: sorted values and sorted index to the original
* position.
*/ */
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0); std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0);
...@@ -341,9 +345,7 @@ std::string ToDebugString(NDArray array); ...@@ -341,9 +345,7 @@ std::string ToDebugString(NDArray array);
// inline implementations // inline implementations
template <typename T> template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec, IdArray VecToIdArray(const std::vector<T>& vec, uint8_t nbits, DGLContext ctx) {
uint8_t nbits,
DGLContext ctx) {
IdArray ret = NewIdArray(vec.size(), DGLContext{kDGLCPU, 0}, nbits); IdArray ret = NewIdArray(vec.size(), DGLContext{kDGLCPU, 0}, nbits);
if (nbits == 32) { if (nbits == 32) {
std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data)); std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
...@@ -367,7 +369,8 @@ inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) { ...@@ -367,7 +369,8 @@ inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) {
first = false; first = false;
result = array->ctx; result = array->ctx;
} else { } else {
CHECK_EQ(array->ctx, result) << "Context of the input arrays are different"; CHECK_EQ(array->ctx, result)
<< "Context of the input arrays are different";
} }
} }
return result; return result;
......
...@@ -9,14 +9,16 @@ ...@@ -9,14 +9,16 @@
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/serializer.h> #include <dmlc/serializer.h>
#include <vector>
#include <utility>
#include <tuple>
#include <string> #include <string>
#include "./types.h" #include <tuple>
#include <utility>
#include <vector>
#include "./array_ops.h" #include "./array_ops.h"
#include "./spmat.h"
#include "./macro.h" #include "./macro.h"
#include "./spmat.h"
#include "./types.h"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -52,9 +54,9 @@ struct COOMatrix { ...@@ -52,9 +54,9 @@ struct COOMatrix {
/** @brief default constructor */ /** @brief default constructor */
COOMatrix() = default; COOMatrix() = default;
/** @brief constructor */ /** @brief constructor */
COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr, COOMatrix(
IdArray darr = NullArray(), bool rsorted = false, int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
bool csorted = false) IdArray darr = NullArray(), bool rsorted = false, bool csorted = false)
: num_rows(nrows), : num_rows(nrows),
num_cols(ncols), num_cols(ncols),
row(rarr), row(rarr),
...@@ -79,8 +81,9 @@ struct COOMatrix { ...@@ -79,8 +81,9 @@ struct COOMatrix {
// Convert to a SparseMatrix object that can return to python. // Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const { SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCOO), num_rows, return SparseMatrix(
num_cols, {row, col, data}, {row_sorted, col_sorted}); static_cast<int32_t>(SparseFormat::kCOO), num_rows, num_cols,
{row, col, data}, {row_sorted, col_sorted});
} }
bool Load(dmlc::Stream* fs) { bool Load(dmlc::Stream* fs) {
...@@ -122,12 +125,12 @@ struct COOMatrix { ...@@ -122,12 +125,12 @@ struct COOMatrix {
} }
/** @brief Return a copy of this matrix on the give device context. */ /** @brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DGLContext &ctx) const { inline COOMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == row->ctx) if (ctx == row->ctx) return *this;
return *this; return COOMatrix(
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx), num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx), aten::IsNullArray(data) ? data : data.CopyTo(ctx), row_sorted,
row_sorted, col_sorted); col_sorted);
} }
/** /**
...@@ -139,8 +142,7 @@ struct COOMatrix { ...@@ -139,8 +142,7 @@ struct COOMatrix {
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
if (is_pinned) if (is_pinned) return;
return;
row.PinMemory_(); row.PinMemory_();
col.PinMemory_(); col.PinMemory_();
if (!aten::IsNullArray(data)) { if (!aten::IsNullArray(data)) {
...@@ -157,8 +159,7 @@ struct COOMatrix { ...@@ -157,8 +159,7 @@ struct COOMatrix {
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
inline void UnpinMemory_() { inline void UnpinMemory_() {
if (!is_pinned) if (!is_pinned) return;
return;
row.UnpinMemory_(); row.UnpinMemory_();
col.UnpinMemory_(); col.UnpinMemory_();
if (!aten::IsNullArray(data)) { if (!aten::IsNullArray(data)) {
...@@ -183,25 +184,25 @@ struct COOMatrix { ...@@ -183,25 +184,25 @@ struct COOMatrix {
///////////////////////// COO routines ////////////////////////// ///////////////////////// COO routines //////////////////////////
/** @brief Return true if the value (row, col) is non-zero */ /** @brief Return true if the value (row, col) is non-zero */
bool COOIsNonZero(COOMatrix , int64_t row, int64_t col); bool COOIsNonZero(COOMatrix, int64_t row, int64_t col);
/** /**
* @brief Batched implementation of COOIsNonZero. * @brief Batched implementation of COOIsNonZero.
* @note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
*/ */
runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col); runtime::NDArray COOIsNonZero(
COOMatrix, runtime::NDArray row, runtime::NDArray col);
/** @brief Return the nnz of the given row */ /** @brief Return the nnz of the given row */
int64_t COOGetRowNNZ(COOMatrix , int64_t row); int64_t COOGetRowNNZ(COOMatrix, int64_t row);
runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row); runtime::NDArray COOGetRowNNZ(COOMatrix, runtime::NDArray row);
/** @brief Return the data array of the given row */ /** @brief Return the data array of the given row */
std::pair<runtime::NDArray, runtime::NDArray> std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
COOGetRowDataAndIndices(COOMatrix , int64_t row); COOMatrix, int64_t row);
/** @brief Whether the COO matrix contains data */ /** @brief Whether the COO matrix contains data */
inline bool COOHasData(COOMatrix csr) { inline bool COOHasData(COOMatrix csr) { return !IsNullArray(csr.data); }
return !IsNullArray(csr.data);
}
/** /**
* @brief Check whether the COO is sorted. * @brief Check whether the COO is sorted.
...@@ -217,11 +218,12 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo); ...@@ -217,11 +218,12 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo);
/** /**
* @brief Get the data and the row,col indices for each returned entries. * @brief Get the data and the row,col indices for each returned entries.
* *
* The operator supports matrix with duplicate entries and all the matched entries * The operator supports matrix with duplicate entries and all the matched
* will be returned. The operator assumes there is NO duplicate (row, col) pair * entries will be returned. The operator assumes there is NO duplicate (row,
* in the given input. Otherwise, the returned result is undefined. * col) pair in the given input. Otherwise, the returned result is undefined.
* *
* @note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
* @param mat Sparse matrix * @param mat Sparse matrix
* @param rows Row index * @param rows Row index
* @param cols Column index * @param cols Column index
...@@ -230,10 +232,13 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo); ...@@ -230,10 +232,13 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo);
std::vector<runtime::NDArray> COOGetDataAndIndices( std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols); COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/** @brief Get data. The return type is an ndarray due to possible duplicate entries. */ /** @brief Get data. The return type is an ndarray due to possible duplicate
* entries. */
inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) { inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
IdArray rows = VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx); IdArray rows =
IdArray cols = VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx); VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);
IdArray cols =
VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);
const auto& rst = COOGetDataAndIndices(mat, rows, cols); const auto& rst = COOGetDataAndIndices(mat, rows, cols);
return rst[2]; return rst[2];
} }
...@@ -241,18 +246,20 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) { ...@@ -241,18 +246,20 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
/** /**
* @brief Get the data for each (row, col) pair. * @brief Get the data for each (row, col) pair.
* *
* The operator supports matrix with duplicate entries but only one matched entry * The operator supports matrix with duplicate entries but only one matched
* will be returned for each (row, col) pair. Support duplicate input (row, col) * entry will be returned for each (row, col) pair. Support duplicate input
* pairs. * (row, col) pairs.
* *
* @note This operator allows broadcasting (i.e, either row or col can be of length 1). * @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
* *
* @param mat Sparse matrix. * @param mat Sparse matrix.
* @param rows Row index. * @param rows Row index.
* @param cols Column index. * @param cols Column index.
* @return Data array. The i^th element is the data of (rows[i], cols[i]) * @return Data array. The i^th element is the data of (rows[i], cols[i])
*/ */
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols); runtime::NDArray COOGetData(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/** @brief Return a transposed COO matrix */ /** @brief Return a transposed COO matrix */
COOMatrix COOTranspose(COOMatrix coo); COOMatrix COOTranspose(COOMatrix coo);
...@@ -301,14 +308,15 @@ COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows); ...@@ -301,14 +308,15 @@ COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
* @param cols The col index to select * @param cols The col index to select
* @return submatrix * @return submatrix
*/ */
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix COOSliceMatrix(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
/** @return True if the matrix has duplicate entries */ /** @return True if the matrix has duplicate entries */
bool COOHasDuplicate(COOMatrix coo); bool COOHasDuplicate(COOMatrix coo);
/** /**
* @brief Deduplicate the entries of a sorted COO matrix, replacing the data with the * @brief Deduplicate the entries of a sorted COO matrix, replacing the data
* number of occurrences of the row-col coordinates. * with the number of occurrences of the row-col coordinates.
*/ */
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo); std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
...@@ -316,11 +324,13 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo); ...@@ -316,11 +324,13 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
* @brief Sort the indices of a COO matrix in-place. * @brief Sort the indices of a COO matrix in-place.
* *
* The function sorts row indices in ascending order. If sort_column is true, * The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix * col indices are sorted in ascending order too. The data array of the returned
* stores the shuffled index which could be used to fetch edge data. * COOMatrix stores the shuffled index which could be used to fetch edge data.
* *
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros. * Complexity: O(N*log(N)) time and O(1) space, where N is the number of
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space. * nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N)
* space.
* *
* @param mat The coo matrix to sort. * @param mat The coo matrix to sort.
* @param sort_column True if column index should be sorted too. * @param sort_column True if column index should be sorted too.
...@@ -331,31 +341,32 @@ void COOSort_(COOMatrix* mat, bool sort_column = false); ...@@ -331,31 +341,32 @@ void COOSort_(COOMatrix* mat, bool sort_column = false);
* @brief Sort the indices of a COO matrix. * @brief Sort the indices of a COO matrix.
* *
* The function sorts row indices in ascending order. If sort_column is true, * The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix * col indices are sorted in ascending order too. The data array of the returned
* stores the shuffled index which could be used to fetch edge data. * COOMatrix stores the shuffled index which could be used to fetch edge data.
* *
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros. * Complexity: O(N*log(N)) time and O(1) space, where N is the number of
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space. * nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N)
* space.
* *
* @param mat The input coo matrix * @param mat The input coo matrix
* @param sort_column True if column index should be sorted too. * @param sort_column True if column index should be sorted too.
* @return COO matrix with index sorted. * @return COO matrix with index sorted.
*/ */
inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) { inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) {
if ((mat.row_sorted && !sort_column) || mat.col_sorted) if ((mat.row_sorted && !sort_column) || mat.col_sorted) return mat;
return mat; COOMatrix ret(
COOMatrix ret(mat.num_rows, mat.num_cols, mat.num_rows, mat.num_cols, mat.row.Clone(), mat.col.Clone(),
mat.row.Clone(), mat.col.Clone(), COOHasData(mat) ? mat.data.Clone() : mat.data, mat.row_sorted,
COOHasData(mat)? mat.data.Clone() : mat.data, mat.col_sorted);
mat.row_sorted, mat.col_sorted);
COOSort_(&ret, sort_column); COOSort_(&ret, sort_column);
return ret; return ret;
} }
/** /**
* @brief Remove entries from COO matrix by entry indices (data indices) * @brief Remove entries from COO matrix by entry indices (data indices)
* @return A new COO matrix as well as a mapping from the new COO entries to the old COO * @return A new COO matrix as well as a mapping from the new COO entries to the
* entries. * old COO entries.
*/ */
COOMatrix COORemove(COOMatrix coo, IdArray entries); COOMatrix COORemove(COOMatrix coo, IdArray entries);
...@@ -365,10 +376,12 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries); ...@@ -365,10 +376,12 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries);
* @param new_row_ids the new row Ids (the index is the old row Id) * @param new_row_ids the new row Ids (the index is the old row Id)
* @param new_col_ids the new column Ids (the index is the old col Id). * @param new_col_ids the new column Ids (the index is the old col Id).
*/ */
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); COOMatrix COOReorder(
COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/** /**
* @brief Randomly select a fixed number of non-zero entries along each given row independently. * @brief Randomly select a fixed number of non-zero entries along each given
* row independently.
* *
* The function performs random choices along each row independently. * The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix. * The picked indices are returned in the form of a COO matrix.
...@@ -400,15 +413,12 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArr ...@@ -400,15 +413,12 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArr
* Should be of the same length as the data array. * Should be of the same length as the data array.
* If an empty array is provided, assume uniform. * If an empty array is provided, assume uniform.
* @param replace True if sample with replacement * @param replace True if sample with replacement
* @return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field
* the index of the picked elements in the value array. * stores the the index of the picked elements in the value array.
*/ */
COOMatrix COORowWiseSampling( COOMatrix COORowWiseSampling(
COOMatrix mat, COOMatrix mat, IdArray rows, int64_t num_samples,
IdArray rows, NDArray prob_or_mask = NDArray(), bool replace = true);
int64_t num_samples,
NDArray prob_or_mask = NDArray(),
bool replace = true);
/** /**
* @brief Randomly select a fixed number of non-zero entries for each edge type * @brief Randomly select a fixed number of non-zero entries for each edge type
...@@ -433,8 +443,8 @@ COOMatrix COORowWiseSampling( ...@@ -433,8 +443,8 @@ COOMatrix COORowWiseSampling(
* COOMatrix coo = ...; * COOMatrix coo = ...;
* IdArray rows = ... ; // [0, 3] * IdArray rows = ... ; // [0, 3]
* std::vector<int64_t> num_samples = {2, 2, 2}; * std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset, num_samples, * COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset,
* FloatArray(), false); * num_samples, FloatArray(), false);
* // possible sampled coo matrix: * // possible sampled coo matrix:
* // sampled.num_rows = 4 * // sampled.num_rows = 4
* // sampled.num_cols = 4 * // sampled.num_cols = 4
...@@ -450,20 +460,18 @@ COOMatrix COORowWiseSampling( ...@@ -450,20 +460,18 @@ COOMatrix COORowWiseSampling(
* Should be of the same length as the data array. * Should be of the same length as the data array.
* If an empty array is provided, assume uniform. * If an empty array is provided, assume uniform.
* @param replace True if sample with replacement * @param replace True if sample with replacement
* @return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field
* the index of the picked elements in the value array. * stores the the index of the picked elements in the value array.
* @note The edges of the entire graph must be ordered by their edge types. * @note The edges of the entire graph must be ordered by their edge types.
*/ */
COOMatrix COORowWisePerEtypeSampling( COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, const std::vector<NDArray>& prob_or_mask, bool replace = true);
bool replace = true);
/** /**
* @brief Select K non-zero entries with the largest weights along each given row. * @brief Select K non-zero entries with the largest weights along each given
* row.
* *
* The function performs top-k selection along each row independently. * The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix. * The picked indices are returned in the form of a COO matrix.
...@@ -492,18 +500,15 @@ COOMatrix COORowWisePerEtypeSampling( ...@@ -492,18 +500,15 @@ COOMatrix COORowWisePerEtypeSampling(
* @param mat Input COO matrix. * @param mat Input COO matrix.
* @param rows Rows to sample from. * @param rows Rows to sample from.
* @param k The K value. * @param k The K value.
* @param weight Weight associated with each entry. Should be of the same length as the * @param weight Weight associated with each entry. Should be of the same length
* data array. If an empty array is provided, assume uniform. * as the data array. If an empty array is provided, assume uniform.
* @param ascending If true, elements are sorted by ascending order, equivalent to find * @param ascending If true, elements are sorted by ascending order, equivalent
* the K smallest values. Otherwise, find K largest values. * to find the K smallest values. Otherwise, find K largest values.
* @return A COOMatrix storing the picked row and col indices. Its data field stores the * @return A COOMatrix storing the picked row and col indices. Its data field
* the index of the picked elements in the value array. * stores the the index of the picked elements in the value array.
*/ */
COOMatrix COORowWiseTopk( COOMatrix COORowWiseTopk(
COOMatrix mat, COOMatrix mat, IdArray rows, int64_t k, NDArray weight,
IdArray rows,
int64_t k,
NDArray weight,
bool ascending = false); bool ascending = false);
/** /**
...@@ -535,8 +540,7 @@ COOMatrix COORowWiseTopk( ...@@ -535,8 +540,7 @@ COOMatrix COORowWiseTopk(
* COOMatrix_C.num_rows : 3 * COOMatrix_C.num_rows : 3
* COOMatrix_C.num_cols : 4 * COOMatrix_C.num_cols : 4
*/ */
COOMatrix UnionCoo( COOMatrix UnionCoo(const std::vector<COOMatrix>& coos);
const std::vector<COOMatrix>& coos);
/** /**
* @brief DisjointUnion a list COOMatrix into one COOMatrix. * @brief DisjointUnion a list COOMatrix into one COOMatrix.
...@@ -566,12 +570,13 @@ COOMatrix UnionCoo( ...@@ -566,12 +570,13 @@ COOMatrix UnionCoo(
* COOMatrix_C.num_cols : 5 * COOMatrix_C.num_cols : 5
* *
* @param coos The input list of coo matrix. * @param coos The input list of coo matrix.
* @param src_offset A list of integers recording src vertix id offset of each Matrix in coos * @param src_offset A list of integers recording src vertix id offset of each
* @param src_offset A list of integers recording dst vertix id offset of each Matrix in coos * Matrix in coos
* @param src_offset A list of integers recording dst vertix id offset of each
* Matrix in coos
* @return The combined COOMatrix. * @return The combined COOMatrix.
*/ */
COOMatrix DisjointUnionCoo( COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);
const std::vector<COOMatrix>& coos);
/** /**
* @brief COOMatrix toSimple. * @brief COOMatrix toSimple.
...@@ -591,8 +596,8 @@ COOMatrix DisjointUnionCoo( ...@@ -591,8 +596,8 @@ COOMatrix DisjointUnionCoo(
* edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4] * edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4]
* *
* @return The simplified COOMatrix * @return The simplified COOMatrix
* The count recording the number of duplicated edges from the original graph. * The count recording the number of duplicated edges from the original
* The edge mapping from the edge IDs of original graph to those of the * graph. The edge mapping from the edge IDs of original graph to those of the
* returned graph. * returned graph.
*/ */
std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo); std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
...@@ -642,11 +647,10 @@ std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo); ...@@ -642,11 +647,10 @@ std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
* @return A list of COOMatrixes representing each disjoint components. * @return A list of COOMatrixes representing each disjoint components.
*/ */
std::vector<COOMatrix> DisjointPartitionCooBySizes( std::vector<COOMatrix> DisjointPartitionCooBySizes(
const COOMatrix &coo, const COOMatrix& coo, const uint64_t batch_size,
const uint64_t batch_size, const std::vector<uint64_t>& edge_cumsum,
const std::vector<uint64_t> &edge_cumsum, const std::vector<uint64_t>& src_vertex_cumsum,
const std::vector<uint64_t> &src_vertex_cumsum, const std::vector<uint64_t>& dst_vertex_cumsum);
const std::vector<uint64_t> &dst_vertex_cumsum);
/** /**
* @brief Slice a contiguous chunk from a COOMatrix * @brief Slice a contiguous chunk from a COOMatrix
...@@ -684,10 +688,9 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes( ...@@ -684,10 +688,9 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
* @return COOMatrix representing the chunk. * @return COOMatrix representing the chunk.
*/ */
COOMatrix COOSliceContiguousChunk( COOMatrix COOSliceContiguousChunk(
const COOMatrix &coo, const COOMatrix& coo, const std::vector<uint64_t>& edge_range,
const std::vector<uint64_t> &edge_range, const std::vector<uint64_t>& src_vertex_range,
const std::vector<uint64_t> &src_vertex_range, const std::vector<uint64_t>& dst_vertex_range);
const std::vector<uint64_t> &dst_vertex_range);
/** /**
* @brief Create a LineGraph of input coo * @brief Create a LineGraph of input coo
...@@ -716,10 +719,11 @@ COOMatrix COOSliceContiguousChunk( ...@@ -716,10 +719,11 @@ COOMatrix COOSliceContiguousChunk(
* [0, 1, 1, 0, 0]] * [0, 1, 1, 0, 0]]
* *
* @param coo COOMatrix to create the LineGraph * @param coo COOMatrix to create the LineGraph
* @param backtracking whether the pair of (v, u) (u, v) edges are treated as linked * @param backtracking whether the pair of (v, u) (u, v) edges are treated as
* linked
* @return LineGraph in COO format * @return LineGraph in COO format
*/ */
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking); COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -223,8 +223,10 @@ bool CSRIsSorted(CSRMatrix csr); ...@@ -223,8 +223,10 @@ bool CSRIsSorted(CSRMatrix csr);
std::vector<runtime::NDArray> CSRGetDataAndIndices( std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
/* @brief Get data. The return type is an ndarray due to possible duplicate /**
* entries. */ * @brief Get data. The return type is an ndarray due to possible duplicate
* entries.
*/
inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) { inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
const auto& nbits = mat.indptr->dtype.bits; const auto& nbits = mat.indptr->dtype.bits;
const auto& ctx = mat.indptr->ctx; const auto& ctx = mat.indptr->ctx;
......
This diff is collapsed.
...@@ -10,7 +10,8 @@ ...@@ -10,7 +10,8 @@
#include <memory> #include <memory>
namespace dgl { namespace dgl {
// Util class to call the private/public empty constructor, which is needed for serialization // Util class to call the private/public empty constructor, which is needed for
// serialization
class Serializer { class Serializer {
public: public:
template <typename T> template <typename T>
......
This diff is collapsed.
...@@ -7,19 +7,18 @@ ...@@ -7,19 +7,18 @@
#define DGL_RUNTIME_PARALLEL_FOR_H_ #define DGL_RUNTIME_PARALLEL_FOR_H_
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <algorithm> #include <algorithm>
#include <string> #include <atomic>
#include <cstdlib> #include <cstdlib>
#include <exception> #include <exception>
#include <vector> #include <string>
#include <atomic>
#include <utility> #include <utility>
#include <vector>
namespace { namespace {
int64_t divup(int64_t x, int64_t y) { int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; }
return (x + y - 1) / y; } // namespace
}
}
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -37,9 +36,7 @@ struct DefaultGrainSizeT { ...@@ -37,9 +36,7 @@ struct DefaultGrainSizeT {
} }
} }
size_t operator()() { size_t operator()() { return grain_size; }
return grain_size;
}
}; };
} // namespace } // namespace
...@@ -48,7 +45,9 @@ inline size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) { ...@@ -48,7 +45,9 @@ inline size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) {
if (omp_in_parallel() || end - begin <= grain_size || end - begin == 1) if (omp_in_parallel() || end - begin <= grain_size || end - begin == 1)
return 1; return 1;
return std::min(static_cast<int64_t>(omp_get_max_threads()), divup(end - begin, grain_size)); return std::min(
static_cast<int64_t>(omp_get_max_threads()),
divup(end - begin, grain_size));
#else #else
return 1; return 1;
#endif #endif
...@@ -60,15 +59,13 @@ static DefaultGrainSizeT default_grain_size; ...@@ -60,15 +59,13 @@ static DefaultGrainSizeT default_grain_size;
* @brief OpenMP-based parallel for loop. * @brief OpenMP-based parallel for loop.
* *
* It requires each thread's workload to have at least \a grain_size elements. * It requires each thread's workload to have at least \a grain_size elements.
* The loop body will be a function that takes in two arguments \a begin and \a end, which * The loop body will be a function that takes in two arguments \a begin and \a
* stands for the starting (inclusive) and ending index (exclusive) of the workload. * end, which stands for the starting (inclusive) and ending index (exclusive)
* of the workload.
*/ */
template <typename F> template <typename F>
void parallel_for( void parallel_for(
const size_t begin, const size_t begin, const size_t end, const size_t grain_size, F&& f) {
const size_t end,
const size_t grain_size,
F&& f) {
if (begin >= end) { if (begin >= end) {
return; return;
} }
...@@ -89,13 +86,11 @@ void parallel_for( ...@@ -89,13 +86,11 @@ void parallel_for(
try { try {
f(begin_tid, end_tid); f(begin_tid, end_tid);
} catch (...) { } catch (...) {
if (!err_flag.test_and_set()) if (!err_flag.test_and_set()) eptr = std::current_exception();
eptr = std::current_exception();
} }
} }
} }
if (eptr) if (eptr) std::rethrow_exception(eptr);
std::rethrow_exception(eptr);
#else #else
f(begin, end); f(begin, end);
#endif #endif
...@@ -110,22 +105,21 @@ void parallel_for( ...@@ -110,22 +105,21 @@ void parallel_for(
* parallel for pragma with static scheduling. * parallel for pragma with static scheduling.
*/ */
template <typename F> template <typename F>
void parallel_for( void parallel_for(const size_t begin, const size_t end, F&& f) {
const size_t begin,
const size_t end,
F&& f) {
parallel_for(begin, end, default_grain_size(), std::forward<F>(f)); parallel_for(begin, end, default_grain_size(), std::forward<F>(f));
} }
/** /**
* @brief OpenMP-based two-stage parallel reduction. * @brief OpenMP-based two-stage parallel reduction.
* *
* The first-stage reduction function \a f works in parallel. Each thread's workload has * The first-stage reduction function \a f works in parallel. Each thread's
* at least \a grain_size elements. The loop body will be a function that takes in * workload has at least \a grain_size elements. The loop body will be a
* the starting index (inclusive), the ending index (exclusive), and the reduction identity. * function that takes in the starting index (inclusive), the ending index
* (exclusive), and the reduction identity.
* *
* The second-stage reduction function \a sf is a binary function working in the main * The second-stage reduction function \a sf is a binary function working in the
* thread. It aggregates the partially reduced result computed from each thread. * main thread. It aggregates the partially reduced result computed from each
* thread.
* *
* Example to compute a parallelized max reduction of an array \c a: * Example to compute a parallelized max reduction of an array \c a:
* *
...@@ -134,11 +128,9 @@ void parallel_for( ...@@ -134,11 +128,9 @@ void parallel_for(
* 100, // ending index * 100, // ending index
* 1, // grain size * 1, // grain size
* -std::numeric_limits<float>::infinity, // identity * -std::numeric_limits<float>::infinity, // identity
* [&a] (int begin, int end, float ident) { // first-stage partial reducer * [&a] (int begin, int end, float ident) { // first-stage partial
* float result = ident; * reducer float result = ident; for (int i = begin; i < end; ++i) result =
* for (int i = begin; i < end; ++i) * std::max(result, a[i]); return result;
* result = std::max(result, a[i]);
* return result;
* }, * },
* [] (float result, float partial_result) { * [] (float result, float partial_result) {
* return std::max(result, partial_result); * return std::max(result, partial_result);
...@@ -146,12 +138,8 @@ void parallel_for( ...@@ -146,12 +138,8 @@ void parallel_for(
*/ */
template <typename DType, typename F, typename SF> template <typename DType, typename F, typename SF>
DType parallel_reduce( DType parallel_reduce(
const size_t begin, const size_t begin, const size_t end, const size_t grain_size,
const size_t end, const DType ident, const F& f, const SF& sf) {
const size_t grain_size,
const DType ident,
const F& f,
const SF& sf) {
if (begin >= end) { if (begin >= end) {
return ident; return ident;
} }
...@@ -174,17 +162,14 @@ DType parallel_reduce( ...@@ -174,17 +162,14 @@ DType parallel_reduce(
try { try {
results[tid] = f(begin_tid, end_tid, ident); results[tid] = f(begin_tid, end_tid, ident);
} catch (...) { } catch (...) {
if (!err_flag.test_and_set()) if (!err_flag.test_and_set()) eptr = std::current_exception();
eptr = std::current_exception();
} }
} }
} }
if (eptr) if (eptr) std::rethrow_exception(eptr);
std::rethrow_exception(eptr);
DType out = ident; DType out = ident;
for (int64_t i = 0; i < num_threads; ++i) for (int64_t i = 0; i < num_threads; ++i) out = sf(out, results[i]);
out = sf(out, results[i]);
return out; return out;
} }
......
This diff is collapsed.
...@@ -8,9 +8,10 @@ ...@@ -8,9 +8,10 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/graph_traversal.h> #include <dgl/graph_traversal.h>
#include <vector>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include <vector>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -82,7 +83,8 @@ template <DGLDeviceType XPU, typename IdType> ...@@ -82,7 +83,8 @@ template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col); bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col); runtime::NDArray CSRIsNonZero(
CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr); bool CSRHasDuplicate(CSRMatrix csr);
...@@ -104,19 +106,21 @@ bool CSRIsSorted(CSRMatrix csr); ...@@ -104,19 +106,21 @@ bool CSRIsSorted(CSRMatrix csr);
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData( runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids, CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler); bool return_eids, runtime::NDArray weights, DType filler);
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData( runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler) { runtime::NDArray weights, DType filler) {
return CSRGetData<XPU, IdType, DType>(csr, rows, cols, false, weights, filler); return CSRGetData<XPU, IdType, DType>(
csr, rows, cols, false, weights, filler);
} }
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) { NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, true, NullArray(rows->dtype), -1); return CSRGetData<XPU, IdType, IdType>(
csr, rows, cols, true, NullArray(rows->dtype), -1);
} }
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
...@@ -141,20 +145,23 @@ template <DGLDeviceType XPU, typename IdType> ...@@ -141,20 +145,23 @@ template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows); CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix CSRSliceMatrix(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr); void CSRSort_(CSRMatrix* csr);
template <DGLDeviceType XPU, typename IdType, typename TagType> template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag( std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, IdArray tag_array, int64_t num_tags); const CSRMatrix& csr, IdArray tag_array, int64_t num_tags);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); CSRMatrix CSRReorder(
CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids); COOMatrix COOReorder(
COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries); CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
...@@ -162,15 +169,16 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries); ...@@ -162,15 +169,16 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseSampling( COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace); CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling( COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace, bool rowwise_etype_sorted); const std::vector<NDArray>& prob_or_mask, bool replace,
bool rowwise_etype_sorted);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform( COOMatrix CSRRowWiseSamplingUniform(
...@@ -178,9 +186,9 @@ COOMatrix CSRRowWiseSamplingUniform( ...@@ -178,9 +186,9 @@ COOMatrix CSRRowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform( COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& eid2etype_offset, const std::vector<int64_t>& num_samples, bool replace,
const std::vector<int64_t>& num_samples, bool replace, bool rowwise_etype_sorted); bool rowwise_etype_sorted);
// FloatType is the type of weight data. // FloatType is the type of weight data.
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
...@@ -189,21 +197,13 @@ COOMatrix CSRRowWiseTopk( ...@@ -189,21 +197,13 @@ COOMatrix CSRRowWiseTopk(
template <DGLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased( COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat, CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
IdArray rows, FloatArray bias, bool replace);
int64_t num_samples,
NDArray tag_offset,
FloatArray bias,
bool replace);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr, const CSRMatrix& csr, int64_t num_samples, int num_trials,
int64_t num_samples, bool exclude_self_loops, bool replace, double redundancy);
int num_trials,
bool exclude_self_loops,
bool replace,
double redundancy);
// Union CSRMatrixes // Union CSRMatrixes
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
...@@ -218,7 +218,8 @@ template <DGLDeviceType XPU, typename IdType> ...@@ -218,7 +218,8 @@ template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col); bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOIsNonZero(COOMatrix coo, runtime::NDArray row, runtime::NDArray col); runtime::NDArray COOIsNonZero(
COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo); bool COOHasDuplicate(COOMatrix coo);
...@@ -230,15 +231,16 @@ template <DGLDeviceType XPU, typename IdType> ...@@ -230,15 +231,16 @@ template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row); runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<runtime::NDArray, runtime::NDArray> std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
COOGetRowDataAndIndices(COOMatrix coo, int64_t row); COOMatrix coo, int64_t row);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> COOGetDataAndIndices( std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols); runtime::NDArray COOGetData(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo); COOMatrix COOTranspose(COOMatrix coo);
...@@ -253,7 +255,8 @@ template <DGLDeviceType XPU, typename IdType> ...@@ -253,7 +255,8 @@ template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows); COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix COOSliceMatrix(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo); std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
...@@ -273,13 +276,13 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries); ...@@ -273,13 +276,13 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWiseSampling( COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace); COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace);
// FloatType is the type of probability data. // FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWisePerEtypeSampling( COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace); const std::vector<NDArray>& prob_or_mask, bool replace);
...@@ -289,8 +292,7 @@ COOMatrix COORowWiseSamplingUniform( ...@@ -289,8 +292,7 @@ COOMatrix COORowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform( COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows, COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace); const std::vector<int64_t>& num_samples, bool replace);
// FloatType is the type of weight data. // FloatType is the type of weight data.
...@@ -313,14 +315,12 @@ template <DGLDeviceType XPU, typename IdType> ...@@ -313,14 +315,12 @@ template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source); Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, Frontiers DGLDFSLabeledEdges(
IdArray source, const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_reverse_edge, const bool has_nontree_edge, const bool return_labels);
const bool has_nontree_edge,
const bool return_labels);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking); COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
/** /**
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* @file kernel/cpu/gaher_mm.cc * @file kernel/cpu/gaher_mm.cc
* @brief GatherMM C APIs and definitions. * @brief GatherMM C APIs and definitions.
*/ */
#include "./gather_mm.h" #include "./gather_mm.h"
#include <dgl/array.h> #include <dgl/array.h>
namespace dgl { namespace dgl {
...@@ -11,81 +12,72 @@ namespace aten { ...@@ -11,81 +12,72 @@ namespace aten {
/** @brief Generalized SegmentMM. */ /** @brief Generalized SegmentMM. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A, void SegmentMM(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
NDArray C,
const NDArray seglen_A,
bool a_trans, bool b_trans) { bool a_trans, bool b_trans) {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMM."; LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
} }
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A, void SegmentMMBackwardB(
const NDArray dC, const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
NDArray dB,
const NDArray seglen) {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB."; LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
} }
/** @brief Generalized GatherMM. */ /** @brief Generalized GatherMM. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A, void GatherMM(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
NDArray C,
const NDArray idx_a,
const NDArray idx_b) { const NDArray idx_b) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM."; LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
} }
/** @brief Generalized GatherMM_scatter. */ /** @brief Generalized GatherMM_scatter. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A, void GatherMMScatter(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
NDArray C, const NDArray idx_b, const NDArray idx_c) {
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM."; LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
} }
template void GatherMM<kDGLCPU, int32_t, float>( template void GatherMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, float>( template void GatherMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, double>( template void GatherMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, double>( template void GatherMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMMScatter<kDGLCPU, int32_t, float>( template void GatherMMScatter<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, float>( template void GatherMMScatter<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, double>( template void GatherMMScatter<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, double>( template void GatherMMScatter<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCPU, int32_t, float>( template void SegmentMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, float>( template void SegmentMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, double>( template void SegmentMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, double>( template void SegmentMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>( template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
......
...@@ -7,13 +7,14 @@ ...@@ -7,13 +7,14 @@
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_ #define DGL_ARRAY_CPU_ROWWISE_PICK_H_
#include <dgl/array.h> #include <dgl/array.h>
#include <dmlc/omp.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <functional> #include <dmlc/omp.h>
#include <algorithm> #include <algorithm>
#include <functional>
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -41,10 +42,10 @@ namespace impl { ...@@ -41,10 +42,10 @@ namespace impl {
template <typename IdxType> template <typename IdxType>
using PickFn = std::function<void( using PickFn = std::function<void(
IdxType rowid, IdxType off, IdxType len, IdxType num_picks, IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data, const IdxType* col, const IdxType* data, IdxType* out_idx)>;
IdxType* out_idx)>;
// User-defined function for determining the number of elements to pick from one row. // User-defined function for determining the number of elements to pick from one
// row.
// //
// The column indices of the given row are stored in // The column indices of the given row are stored in
// [col + off, col + off + len) // [col + off, col + off + len)
...@@ -63,8 +64,8 @@ using PickFn = std::function<void( ...@@ -63,8 +64,8 @@ using PickFn = std::function<void(
// \param data Pointer of the data indices. // \param data Pointer of the data indices.
template <typename IdxType> template <typename IdxType>
using NumPicksFn = std::function<IdxType( using NumPicksFn = std::function<IdxType(
IdxType rowid, IdxType off, IdxType len, IdxType rowid, IdxType off, IdxType len, const IdxType* col,
const IdxType* col, const IdxType* data)>; const IdxType* data)>;
// User-defined function for picking elements from a range within a row. // User-defined function for picking elements from a range within a row.
// //
...@@ -73,7 +74,8 @@ using NumPicksFn = std::function<IdxType( ...@@ -73,7 +74,8 @@ using NumPicksFn = std::function<IdxType(
// //
// Similarly, the data indices are stored in // Similarly, the data indices are stored in
// data[off+et_idx[et_offset+i])] // data[off+et_idx[et_offset+i])]
// Data index pointer could be NULL, which means data[i] == off+et_idx[et_offset+i]) // Data index pointer could be NULL, which means data[i] ==
// off+et_idx[et_offset+i])
// //
// *ATTENTION*: This function will be invoked concurrently. Please make sure // *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe. // it is thread-safe.
...@@ -93,15 +95,17 @@ using EtypeRangePickFn = std::function<void( ...@@ -93,15 +95,17 @@ using EtypeRangePickFn = std::function<void(
const IdxType* eid, IdxType* out_idx)>; const IdxType* eid, IdxType* out_idx)>;
// Template for picking non-zero values row-wise. The implementation utilizes // Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently. // OpenMP parallelization on rows because each row performs computation
// independently.
template <typename IdxType> template <typename IdxType>
COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, COOMatrix CSRRowWisePick(
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn, CSRMatrix mat, IdArray rows, int64_t num_picks, bool replace,
NumPicksFn<IdxType> num_picks_fn) { PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
using namespace aten; using namespace aten;
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data); const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
const IdxType* indices = static_cast<IdxType*>(mat.indices->data); const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
const IdxType* data = CSRHasData(mat)? static_cast<IdxType*>(mat.data->data) : nullptr; const IdxType* data =
CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;
const IdxType* rows_data = static_cast<IdxType*>(rows->data); const IdxType* rows_data = static_cast<IdxType*>(rows->data);
const int64_t num_rows = rows->shape[0]; const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx; const auto& ctx = mat.indptr->ctx;
...@@ -115,30 +119,36 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -115,30 +119,36 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// have at least num_picks number of nnz when replace is false. // have at least num_picks number of nnz when replace is false.
// //
// If the check holds, remove -1 elements by remove_if operation, which simply // If the check holds, remove -1 elements by remove_if operation, which simply
// moves valid elements to the head of arrays and create a view of the original // moves valid elements to the head of arrays and create a view of the
// array. The implementation consumes a little extra memory than the actual requirement. // original array. The implementation consumes a little extra memory than the
// actual requirement.
// //
// Otherwise, directly use the row and col arrays to construct the result COO matrix. // Otherwise, directly use the row and col arrays to construct the result COO
// matrix.
// //
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more // [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism
// is more
// significant. (minjie) // significant. (minjie)
// Do not use omp_get_max_threads() since that doesn't work for compiling without OpenMP. // Do not use omp_get_max_threads() since that doesn't work for compiling
// without OpenMP.
const int num_threads = runtime::compute_num_threads(0, num_rows, 1); const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
std::vector<int64_t> global_prefix(num_threads + 1, 0); std::vector<int64_t> global_prefix(num_threads + 1, 0);
// TODO(BarclayII) Using OMP parallel directly instead of using runtime::parallel_for // TODO(BarclayII) Using OMP parallel directly instead of using
// does not handle exceptions well (directly aborts when an exception pops up). // runtime::parallel_for does not handle exceptions well (directly aborts when
// It runs faster though because there is less scheduling. Need to handle // an exception pops up). It runs faster though because there is less
// exceptions better. // scheduling. Need to handle exceptions better.
IdArray picked_row, picked_col, picked_idx; IdArray picked_row, picked_col, picked_idx;
#pragma omp parallel num_threads(num_threads) #pragma omp parallel num_threads(num_threads)
{ {
const int thread_id = omp_get_thread_num(); const int thread_id = omp_get_thread_num();
const int64_t start_i = thread_id * (num_rows/num_threads) + const int64_t start_i =
thread_id * (num_rows / num_threads) +
std::min(static_cast<int64_t>(thread_id), num_rows % num_threads); std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
const int64_t end_i = (thread_id + 1) * (num_rows/num_threads) + const int64_t end_i =
(thread_id + 1) * (num_rows / num_threads) +
std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads); std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
assert(thread_id + 1 < num_threads || end_i == num_rows); assert(thread_id + 1 < num_threads || end_i == num_rows);
...@@ -149,7 +159,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -149,7 +159,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
local_prefix[0] = 0; local_prefix[0] = 0;
for (int64_t i = start_i; i < end_i; ++i) { for (int64_t i = start_i; i < end_i; ++i) {
// build prefix-sum // build prefix-sum
const int64_t local_i = i-start_i; const int64_t local_i = i - start_i;
const IdxType rid = rows_data[i]; const IdxType rid = rows_data[i];
IdxType len = num_picks_fn( IdxType len = num_picks_fn(
rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data); rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
...@@ -157,8 +167,8 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -157,8 +167,8 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
} }
global_prefix[thread_id + 1] = local_prefix[num_local]; global_prefix[thread_id + 1] = local_prefix[num_local];
#pragma omp barrier #pragma omp barrier
#pragma omp master #pragma omp master
{ {
for (int t = 0; t < num_threads; ++t) { for (int t = 0; t < num_threads; ++t) {
global_prefix[t + 1] += global_prefix[t]; global_prefix[t + 1] += global_prefix[t];
...@@ -168,7 +178,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -168,7 +178,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
} }
#pragma omp barrier #pragma omp barrier
IdxType* picked_rdata = picked_row.Ptr<IdxType>(); IdxType* picked_rdata = picked_row.Ptr<IdxType>();
IdxType* picked_cdata = picked_col.Ptr<IdxType>(); IdxType* picked_cdata = picked_col.Ptr<IdxType>();
IdxType* picked_idata = picked_idx.Ptr<IdxType>(); IdxType* picked_idata = picked_idx.Ptr<IdxType>();
...@@ -180,14 +190,15 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -180,14 +190,15 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
const IdxType off = indptr[rid]; const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off; const IdxType len = indptr[rid + 1] - off;
if (len == 0) if (len == 0) continue;
continue;
const int64_t local_i = i - start_i; const int64_t local_i = i - start_i;
const int64_t row_offset = thread_offset + local_prefix[local_i]; const int64_t row_offset = thread_offset + local_prefix[local_i];
const int64_t num_picks = thread_offset + local_prefix[local_i + 1] - row_offset; const int64_t num_picks =
thread_offset + local_prefix[local_i + 1] - row_offset;
pick_fn(rid, off, len, num_picks, indices, data, picked_idata + row_offset); pick_fn(
rid, off, len, num_picks, indices, data, picked_idata + row_offset);
for (int64_t j = 0; j < num_picks; ++j) { for (int64_t j = 0; j < num_picks; ++j) {
const IdxType picked = picked_idata[row_offset + j]; const IdxType picked = picked_idata[row_offset + j];
picked_rdata[row_offset + j] = rid; picked_rdata[row_offset + j] = rid;
...@@ -200,25 +211,25 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -200,25 +211,25 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
const int64_t new_len = global_prefix.back(); const int64_t new_len = global_prefix.back();
return COOMatrix( return COOMatrix(
mat.num_rows, mat.num_rows, mat.num_cols,
mat.num_cols,
picked_row.CreateView({new_len}, picked_row->dtype), picked_row.CreateView({new_len}, picked_row->dtype),
picked_col.CreateView({new_len}, picked_row->dtype), picked_col.CreateView({new_len}, picked_row->dtype),
picked_idx.CreateView({new_len}, picked_row->dtype)); picked_idx.CreateView({new_len}, picked_row->dtype));
} }
// Template for picking non-zero values row-wise. The implementation utilizes // Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently. // OpenMP parallelization on rows because each row performs computation
// independently.
template <typename IdxType, typename DType> template <typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, COOMatrix CSRRowWisePerEtypePick(
const std::vector<int64_t>& eid2etype_offset, CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_picks, bool replace, const std::vector<int64_t>& num_picks, bool replace,
bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn, bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,
const std::vector<NDArray>& prob_or_mask) { const std::vector<NDArray>& prob_or_mask) {
using namespace aten; using namespace aten;
const IdxType* indptr = mat.indptr.Ptr<IdxType>(); const IdxType* indptr = mat.indptr.Ptr<IdxType>();
const IdxType* indices = mat.indices.Ptr<IdxType>(); const IdxType* indices = mat.indices.Ptr<IdxType>();
const IdxType* eid = CSRHasData(mat)? mat.data.Ptr<IdxType>() : nullptr; const IdxType* eid = CSRHasData(mat) ? mat.data.Ptr<IdxType>() : nullptr;
const IdxType* rows_data = rows.Ptr<IdxType>(); const IdxType* rows_data = rows.Ptr<IdxType>();
const int64_t num_rows = rows->shape[0]; const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx; const auto& ctx = mat.indptr->ctx;
...@@ -229,8 +240,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, ...@@ -229,8 +240,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
std::vector<IdArray> picked_idxs(rows->shape[0]); std::vector<IdArray> picked_idxs(rows->shape[0]);
// Check if the number of picks have the same value. // Check if the number of picks have the same value.
// If so, we can potentially speed up if we have a node with total number of neighbors // If so, we can potentially speed up if we have a node with total number of
// less than the given number of picks with replace=False. // neighbors less than the given number of picks with replace=False.
bool same_num_pick = true; bool same_num_pick = true;
int64_t num_pick_value = num_picks[0]; int64_t num_pick_value = num_picks[0];
for (int64_t num_pick : num_picks) { for (int64_t num_pick : num_picks) {
...@@ -267,9 +278,10 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, ...@@ -267,9 +278,10 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
for (int64_t j = 0; j < len; ++j) { for (int64_t j = 0; j < len; ++j) {
const IdxType homogenized_eid = eid ? eid[off + j] : off + j; const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
auto it = std::upper_bound( auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid); eid2etype_offset.begin(), eid2etype_offset.end(),
homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1; const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \ const IdxType heterogenized_eid =
homogenized_eid - eid2etype_offset[heterogenized_etype]; homogenized_eid - eid2etype_offset[heterogenized_etype];
if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) { if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
...@@ -305,17 +317,21 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, ...@@ -305,17 +317,21 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
for (int64_t j = 0; j < len; ++j) { for (int64_t j = 0; j < len; ++j) {
const IdxType homogenized_eid = eid ? eid[off + j] : off + j; const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
auto it = std::upper_bound( auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid); eid2etype_offset.begin(), eid2etype_offset.end(),
homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1; const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \ const IdxType heterogenized_eid =
homogenized_eid - eid2etype_offset[heterogenized_etype]; homogenized_eid - eid2etype_offset[heterogenized_etype];
et[j] = heterogenized_etype; et[j] = heterogenized_etype;
et_eid[j] = heterogenized_eid; et_eid[j] = heterogenized_eid;
} }
if (!rowwise_etype_sorted) // the edge type is sorted, not need to sort it if (!rowwise_etype_sorted) // the edge type is sorted, not need to sort
std::sort(et_idx.begin(), et_idx.end(), // it
[&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];}); std::sort(
CHECK_LT(et[et_idx[len - 1]], num_etypes) << "etype values exceed the number of fanouts"; et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) { return et[i1] < et[i2]; });
CHECK_LT(et[et_idx[len - 1]], num_etypes)
<< "etype values exceed the number of fanouts";
IdxType cur_et = et[et_idx[0]]; IdxType cur_et = et[et_idx[0]];
int64_t et_offset = 0; int64_t et_offset = 0;
...@@ -333,14 +349,18 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, ...@@ -333,14 +349,18 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
// fast path, select all // fast path, select all
for (int64_t k = 0; k < et_len; ++k) { for (int64_t k = 0; k < et_len; ++k) {
const IdxType eid_offset = off + et_idx[et_offset + k]; const IdxType eid_offset = off + et_idx[et_offset + k];
const IdxType homogenized_eid = eid ? eid[eid_offset] : eid_offset; const IdxType homogenized_eid =
eid ? eid[eid_offset] : eid_offset;
auto it = std::upper_bound( auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid); eid2etype_offset.begin(), eid2etype_offset.end(),
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1; homogenized_eid);
const IdxType heterogenized_eid = \ const IdxType heterogenized_etype =
it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid =
homogenized_eid - eid2etype_offset[heterogenized_etype]; homogenized_eid - eid2etype_offset[heterogenized_etype];
if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) { if (!has_probs ||
IsNullArray(prob_or_mask[heterogenized_etype])) {
// No probability, select all // No probability, select all
rows.push_back(rid); rows.push_back(rid);
cols.push_back(indices[eid_offset]); cols.push_back(indices[eid_offset]);
...@@ -357,32 +377,31 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, ...@@ -357,32 +377,31 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
} }
} }
} else { } else {
IdArray picked_idx = Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx); IdArray picked_idx =
Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
IdxType* picked_idata = picked_idx.Ptr<IdxType>(); IdxType* picked_idata = picked_idx.Ptr<IdxType>();
// need call random pick // need call random pick
pick_fn(off, et_offset, cur_et, pick_fn(
et_len, et_idx, et_eid, off, et_offset, cur_et, et_len, et_idx, et_eid, eid,
eid, picked_idata); picked_idata);
for (int64_t k = 0; k < num_picks[cur_et]; ++k) { for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
const IdxType picked = picked_idata[k]; const IdxType picked = picked_idata[k];
if (picked == -1) if (picked == -1) continue;
continue;
rows.push_back(rid); rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+picked]]); cols.push_back(indices[off + et_idx[et_offset + picked]]);
if (eid) { if (eid) {
idx.push_back(eid[off+et_idx[et_offset+picked]]); idx.push_back(eid[off + et_idx[et_offset + picked]]);
} else { } else {
idx.push_back(off+et_idx[et_offset+picked]); idx.push_back(off + et_idx[et_offset + picked]);
} }
} }
} }
if (j+1 == len) if (j + 1 == len) break;
break;
// next etype // next etype
cur_et = et[et_idx[j+1]]; cur_et = et[et_idx[j + 1]];
et_offset = j+1; et_offset = j + 1;
et_len = 1; et_len = 1;
} else { } else {
et_len++; et_len++;
...@@ -402,31 +421,32 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows, ...@@ -402,31 +421,32 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
IdArray picked_row = Concat(picked_rows); IdArray picked_row = Concat(picked_rows);
IdArray picked_col = Concat(picked_cols); IdArray picked_col = Concat(picked_cols);
IdArray picked_idx = Concat(picked_idxs); IdArray picked_idx = Concat(picked_idxs);
return COOMatrix(mat.num_rows, mat.num_cols, return COOMatrix(
picked_row, picked_col, picked_idx); mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
} }
// Template for picking non-zero values row-wise. The implementation first slices // Template for picking non-zero values row-wise. The implementation first
// out the corresponding rows and then converts it to CSR format. It then performs // slices out the corresponding rows and then converts it to CSR format. It then
// row-wise pick on the CSR matrix and rectifies the returned results. // performs row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType> template <typename IdxType>
COOMatrix COORowWisePick(COOMatrix mat, IdArray rows, COOMatrix COORowWisePick(
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn, COOMatrix mat, IdArray rows, int64_t num_picks, bool replace,
NumPicksFn<IdxType> num_picks_fn) { PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
using namespace aten; using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows)); const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx); const IdArray new_rows =
Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePick<IdxType>( const auto& picked = CSRRowWisePick<IdxType>(
csr, new_rows, num_picks, replace, pick_fn, num_picks_fn); csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);
return COOMatrix(mat.num_rows, mat.num_cols, return COOMatrix(
mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col, picked.col, picked.data);
picked.data);
} }
// Template for picking non-zero values row-wise. The implementation first slices // Template for picking non-zero values row-wise. The implementation first
// out the corresponding rows and then converts it to CSR format. It then performs // slices out the corresponding rows and then converts it to CSR format. It then
// row-wise pick on the CSR matrix and rectifies the returned results. // performs row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType, typename DType> template <typename IdxType, typename DType>
COOMatrix COORowWisePerEtypePick( COOMatrix COORowWisePerEtypePick(
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset, COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
...@@ -435,13 +455,15 @@ COOMatrix COORowWisePerEtypePick( ...@@ -435,13 +455,15 @@ COOMatrix COORowWisePerEtypePick(
const std::vector<NDArray>& prob_or_mask) { const std::vector<NDArray>& prob_or_mask) {
using namespace aten; using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows)); const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx); const IdArray new_rows =
Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>( const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>(
csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn, prob_or_mask); csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn,
return COOMatrix(mat.num_rows, mat.num_cols, prob_or_mask);
return COOMatrix(
mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col, picked.col, picked.data);
picked.data);
} }
} // namespace impl } // namespace impl
......
This diff is collapsed.
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
* @file array/cpu/rowwise_topk.cc * @file array/cpu/rowwise_topk.cc
* @brief rowwise topk * @brief rowwise topk
*/ */
#include <numeric>
#include <algorithm> #include <algorithm>
#include <numeric>
#include "./rowwise_pick.h" #include "./rowwise_pick.h"
namespace dgl { namespace dgl {
...@@ -14,9 +15,9 @@ namespace { ...@@ -14,9 +15,9 @@ namespace {
template <typename IdxType> template <typename IdxType>
inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) { inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
NumPicksFn<IdxType> num_picks_fn = [k] NumPicksFn<IdxType> num_picks_fn = [k](IdxType rowid, IdxType off,
(IdxType rowid, IdxType off, IdxType len, IdxType len, const IdxType* col,
const IdxType* col, const IdxType* data) { const IdxType* data) {
const int64_t max_num_picks = (k == -1) ? len : k; const int64_t max_num_picks = (k == -1) ? len : k;
return std::min(static_cast<IdxType>(max_num_picks), len); return std::min(static_cast<IdxType>(max_num_picks), len);
}; };
...@@ -26,28 +27,28 @@ inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) { ...@@ -26,28 +27,28 @@ inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
template <typename IdxType, typename DType> template <typename IdxType, typename DType>
inline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) { inline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) {
const DType* wdata = static_cast<DType*>(weight->data); const DType* wdata = static_cast<DType*>(weight->data);
PickFn<IdxType> pick_fn = [ascending, wdata] PickFn<IdxType> pick_fn = [ascending, wdata](
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks, IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data, IdxType num_picks, const IdxType* col,
IdxType* out_idx) { const IdxType* data, IdxType* out_idx) {
std::function<bool(IdxType, IdxType)> compare_fn; std::function<bool(IdxType, IdxType)> compare_fn;
if (ascending) { if (ascending) {
if (data) { if (data) {
compare_fn = [wdata, data] (IdxType i, IdxType j) { compare_fn = [wdata, data](IdxType i, IdxType j) {
return wdata[data[i]] < wdata[data[j]]; return wdata[data[i]] < wdata[data[j]];
}; };
} else { } else {
compare_fn = [wdata] (IdxType i, IdxType j) { compare_fn = [wdata](IdxType i, IdxType j) {
return wdata[i] < wdata[j]; return wdata[i] < wdata[j];
}; };
} }
} else { } else {
if (data) { if (data) {
compare_fn = [wdata, data] (IdxType i, IdxType j) { compare_fn = [wdata, data](IdxType i, IdxType j) {
return wdata[data[i]] > wdata[data[j]]; return wdata[data[i]] > wdata[data[j]];
}; };
} else { } else {
compare_fn = [wdata] (IdxType i, IdxType j) { compare_fn = [wdata](IdxType i, IdxType j) {
return wdata[i] > wdata[j]; return wdata[i] > wdata[j];
}; };
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* @brief SDDMM C APIs and definitions. * @brief SDDMM C APIs and definitions.
*/ */
#include "./sddmm.h" #include "./sddmm.h"
#include <dgl/array.h> #include <dgl/array.h>
namespace dgl { namespace dgl {
...@@ -25,7 +26,7 @@ namespace aten { ...@@ -25,7 +26,7 @@ namespace aten {
} \ } \
} while (0) } while (0)
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\ #define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
do { \ do { \
if ((lhs_target) == 0) { \ if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \ constexpr int LhsTarget = 0; \
...@@ -41,35 +42,26 @@ namespace aten { ...@@ -41,35 +42,26 @@ namespace aten {
} \ } \
} while (0) } while (0)
/** @brief Generalized SDDMM on Csr format. */ /** @brief Generalized SDDMM on Csr format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op, void SDDMMCsr(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out); cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
}); });
}); });
} }
/** @brief Generalized SDDMM on Csr format with Heterograph support. */ /** @brief Generalized SDDMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op, void SDDMMCsrHetero(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
const std::vector<NDArray>& vec_rhs, int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) { const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
...@@ -79,7 +71,8 @@ void SDDMMCsrHetero(const std::string& op, ...@@ -79,7 +71,8 @@ void SDDMMCsrHetero(const std::string& op,
NDArray lhs = vec_lhs[lhs_nid[etype]]; NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]]; NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray out = vec_out[etype]; NDArray out = vec_out[etype];
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out); cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
} }
}); });
}); });
...@@ -87,78 +80,62 @@ void SDDMMCsrHetero(const std::string& op, ...@@ -87,78 +80,62 @@ void SDDMMCsrHetero(const std::string& op,
template void SDDMMCsr<kDGLCPU, int32_t, float>( template void SDDMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, float>( template void SDDMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, double>( template void SDDMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, double>( template void SDDMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>( template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, float>( template void SDDMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, double>( template void SDDMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, double>( template void SDDMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
/** @brief Generalized SDDMM on Coo format. */ /** @brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op, void SDDMMCoo(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out); cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
}); });
}); });
} }
/** @brief Generalized SDDMM on Coo format with Heterograph support. */ /** @brief Generalized SDDMM on Coo format with Heterograph support. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op, void SDDMMCooHetero(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
const std::vector<NDArray>& vec_rhs, int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) { const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
...@@ -168,7 +145,8 @@ void SDDMMCooHetero(const std::string& op, ...@@ -168,7 +145,8 @@ void SDDMMCooHetero(const std::string& op,
NDArray lhs = vec_lhs[lhs_nid[etype]]; NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]]; NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray out = vec_out[etype]; NDArray out = vec_out[etype];
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out); cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
} }
}); });
}); });
...@@ -176,48 +154,40 @@ void SDDMMCooHetero(const std::string& op, ...@@ -176,48 +154,40 @@ void SDDMMCooHetero(const std::string& op,
template void SDDMMCoo<kDGLCPU, int32_t, float>( template void SDDMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, float>( template void SDDMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, double>( template void SDDMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, double>( template void SDDMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCooHetero<kDGLCPU, int32_t, float>( template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, float>( template void SDDMMCooHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, double>( template void SDDMMCooHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, double>( template void SDDMMCooHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
} // namespace aten } // namespace aten
......
...@@ -4,8 +4,11 @@ ...@@ -4,8 +4,11 @@
* @brief Segment reduce C APIs and definitions. * @brief Segment reduce C APIs and definitions.
*/ */
#include "./segment_reduce.h" #include "./segment_reduce.h"
#include <dgl/array.h> #include <dgl/array.h>
#include <string> #include <string>
#include "./spmm_binary_ops.h" #include "./spmm_binary_ops.h"
namespace dgl { namespace dgl {
...@@ -14,10 +17,7 @@ namespace aten { ...@@ -14,10 +17,7 @@ namespace aten {
/** @brief Segment Reduce operator. */ /** @brief Segment Reduce operator. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SegmentReduce( void SegmentReduce(
const std::string& op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg) { NDArray arg) {
if (op == "sum") { if (op == "sum") {
cpu::SegmentSum<IdType, DType>(feat, offsets, out); cpu::SegmentSum<IdType, DType>(feat, offsets, out);
...@@ -36,73 +36,47 @@ void SegmentReduce( ...@@ -36,73 +36,47 @@ void SegmentReduce(
/** @brief Scatter Add.*/ /** @brief Scatter Add.*/
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
NDArray idx,
NDArray out) {
cpu::ScatterAdd<IdType, DType>(feat, idx, out); cpu::ScatterAdd<IdType, DType>(feat, idx, out);
} }
/** @brief Update gradients for reduce operator max/min on heterogeneous graph.*/ /** @brief Update gradients for reduce operator max/min on heterogeneous
* graph.*/
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g, void UpdateGradMinMax_hetero(
const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx, const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out); cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
} }
/** @brief Backward function of segment cmp.*/ /** @brief Backward function of segment cmp.*/
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp( void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
NDArray feat,
NDArray arg,
NDArray out) {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out); cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
} }
template void SegmentReduce<kDGLCPU, int32_t, float>( template void SegmentReduce<kDGLCPU, int32_t, float>(
const std::string &op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, float>( template void SegmentReduce<kDGLCPU, int64_t, float>(
const std::string &op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, double>( template void SegmentReduce<kDGLCPU, int32_t, double>(
const std::string &op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, double>( template void SegmentReduce<kDGLCPU, int64_t, double>(
const std::string &op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void ScatterAdd<kDGLCPU, int32_t, float>( template void ScatterAdd<kDGLCPU, int32_t, float>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, float>( template void ScatterAdd<kDGLCPU, int64_t, float>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, double>( template void ScatterAdd<kDGLCPU, int32_t, double>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, double>( template void ScatterAdd<kDGLCPU, int64_t, double>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>( template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
...@@ -122,21 +96,13 @@ template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>( ...@@ -122,21 +96,13 @@ template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>( template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>( template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, double>( template void BackwardSegmentCmp<kDGLCPU, int32_t, double>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, double>( template void BackwardSegmentCmp<kDGLCPU, int64_t, double>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* @brief SPMM C APIs and definitions. * @brief SPMM C APIs and definitions.
*/ */
#include "./spmm.h" #include "./spmm.h"
#include <dgl/array.h> #include <dgl/array.h>
namespace dgl { namespace dgl {
...@@ -11,12 +12,9 @@ namespace aten { ...@@ -11,12 +12,9 @@ namespace aten {
/** @brief Generalized SpMM on Csr format. */ /** @brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce, void SpMMCsr(
const BcastOff& bcast, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) { std::vector<NDArray> out_aux) {
const int64_t dim = bcast.out_len; const int64_t dim = bcast.out_len;
if (reduce == "sum") { if (reduce == "sum") {
...@@ -25,13 +23,15 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -25,13 +23,15 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
}); });
} else if (reduce == "max" || reduce == "min") { } else if (reduce == "max" || reduce == "min") {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
DType *out_off = out.Ptr<DType>(); DType* out_off = out.Ptr<DType>();
if (reduce == "max") { if (reduce == "max") {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero); std::fill(
out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>( cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]); bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
} else { } else {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero); std::fill(
out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>( cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]); bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
} }
...@@ -43,12 +43,11 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -43,12 +43,11 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
/** @brief Generalized SpMM on Csr format. */ /** @brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce, void SpMMCsrHetero(
const BcastOff& bcast, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat, const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat, const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
std::vector<NDArray>* vec_out,
std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids) { const std::vector<dgl_type_t>& out_node_tids) {
...@@ -60,8 +59,10 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -60,8 +59,10 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const dgl_type_t src_id = ufeat_node_tids[etype]; const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype]; const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype]; CSRMatrix csr = vec_csr[etype];
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id]; NDArray ufeat =
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype]; (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat =
(vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = (*vec_out)[dst_id]; NDArray out = (*vec_out)[dst_id];
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out); cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
} }
...@@ -71,21 +72,27 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -71,21 +72,27 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
std::vector<bool> updated((*vec_out).size(), false); std::vector<bool> updated((*vec_out).size(), false);
// TODO(Israt): use vector updated to fill(out...) too // TODO(Israt): use vector updated to fill(out...) too
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) { for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>(); DType* out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
if (reduce == "max") if (reduce == "max")
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero); std::fill(
out_off, out_off + vec_csr[etype].num_rows * dim,
cpu::op::Max<DType>::zero);
else else
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero); std::fill(
out_off, out_off + vec_csr[etype].num_rows * dim,
cpu::op::Min<DType>::zero);
const dgl_type_t dst_id = out_node_tids[etype]; const dgl_type_t dst_id = out_node_tids[etype];
if (!updated[dst_id]) { if (!updated[dst_id]) {
updated[dst_id] = true; updated[dst_id] = true;
if (Op::use_lhs) { if (Op::use_lhs) {
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>(); IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1); std::fill(
argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
} }
if (Op::use_rhs) { if (Op::use_rhs) {
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>(); IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1); std::fill(
arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
} }
} }
} }
...@@ -94,17 +101,21 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -94,17 +101,21 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const dgl_type_t src_id = ufeat_node_tids[etype]; const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype]; const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype]; CSRMatrix csr = vec_csr[etype];
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id]; NDArray ufeat =
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype]; (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat =
(vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = (*vec_out)[dst_id]; NDArray out = (*vec_out)[dst_id];
if (reduce == "max") { if (reduce == "max") {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>( cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id], bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype); (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
(*out_aux)[3][dst_id], src_id, etype);
} else { } else {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>( cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id], bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype); (*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
(*out_aux)[3][dst_id], src_id, etype);
} }
} }
}); });
...@@ -114,119 +125,104 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -114,119 +125,104 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
} }
template void SpMMCsr<kDGLCPU, int32_t, float>( template void SpMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, float>( template void SpMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, double>( template void SpMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, double>( template void SpMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDGLCPU, int32_t, float>( template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, float>( template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, double>( template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, double>( template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
/** @brief Edge_softmax_csr forward op on Csr format. */ /** @brief Edge_softmax_csr forward op on Csr format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(const std::string& op, void Edge_softmax_csr_forward(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out) {
NDArray ufeat,
NDArray efeat,
NDArray out) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
cpu::Edge_softmax_csr_forward<IdType, DType, Op>(bcast, csr, ufeat, efeat, out); cpu::Edge_softmax_csr_forward<IdType, DType, Op>(
bcast, csr, ufeat, efeat, out);
}); });
} }
/** @brief Edge_softmax_csr backward op on Csr format. */ /** @brief Edge_softmax_csr backward op on Csr format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(const std::string& op, void Edge_softmax_csr_backward(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const CSRMatrix& csr, NDArray out, NDArray sds, NDArray back_out) {
NDArray out,
NDArray sds,
NDArray back_out) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
cpu::Edge_softmax_csr_backward<IdType, DType, Op>(bcast, csr, out, sds, back_out); cpu::Edge_softmax_csr_backward<IdType, DType, Op>(
bcast, csr, out, sds, back_out);
}); });
} }
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>( template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>( template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>( template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>( template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>( template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>( template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>( template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>( template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
/** @brief Generalized SpMM on Coo format. */ /** @brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce, void SpMMCoo(
const BcastOff& bcast, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) { std::vector<NDArray> out_aux) {
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
...@@ -247,21 +243,21 @@ void SpMMCoo(const std::string& op, const std::string& reduce, ...@@ -247,21 +243,21 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
} }
template void SpMMCoo<kDGLCPU, int32_t, float>( template void SpMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, float>( template void SpMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, double>( template void SpMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, double>( template void SpMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_ #define DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/bcast.h> #include <dgl/bcast.h>
#include <limits> #include <limits>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
......
...@@ -63,8 +63,10 @@ template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray); ...@@ -63,8 +63,10 @@ template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#if BF16_ENABLED #if BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int32_t>(
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(NDArray, IdArray); NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(
NDArray, IdArray);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray); template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
...@@ -76,8 +78,8 @@ DType IndexSelect(NDArray array, int64_t index) { ...@@ -76,8 +78,8 @@ DType IndexSelect(NDArray array, int64_t index) {
auto device = runtime::DeviceAPI::Get(array->ctx); auto device = runtime::DeviceAPI::Get(array->ctx);
DType ret = static_cast<DType>(0.0f); DType ret = static_cast<DType>(0.0f);
device->CopyDataFromTo( device->CopyDataFromTo(
static_cast<DType*>(array->data) + index, 0, &ret, 0, static_cast<DType*>(array->data) + index, 0, &ret, 0, sizeof(DType),
sizeof(DType), array->ctx, DGLContext{kDGLCPU, 0}, array->dtype); array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
return ret; return ret;
} }
...@@ -87,7 +89,8 @@ template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index); ...@@ -87,7 +89,8 @@ template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index); template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index); template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
#if BF16_ENABLED #if BF16_ENABLED
template __nv_bfloat16 IndexSelect<kDGLCUDA, __nv_bfloat16>(NDArray array, int64_t index); template __nv_bfloat16 IndexSelect<kDGLCUDA, __nv_bfloat16>(
NDArray array, int64_t index);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index); template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index); template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
......
This diff is collapsed.
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* @brief Array scatter GPU implementation * @brief Array scatter GPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
...@@ -13,8 +14,8 @@ namespace aten { ...@@ -13,8 +14,8 @@ namespace aten {
namespace impl { namespace impl {
template <typename DType, typename IdType> template <typename DType, typename IdType>
__global__ void _ScatterKernel(const IdType* index, const DType* value, __global__ void _ScatterKernel(
int64_t length, DType* out) { const IdType* index, const DType* value, int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x; int stride_x = gridDim.x * blockDim.x;
while (tx < length) { while (tx < length) {
...@@ -33,15 +34,15 @@ void Scatter_(IdArray index, NDArray value, NDArray out) { ...@@ -33,15 +34,15 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const int nt = cuda::FindNumThreads(len); const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt; const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream, CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream, idx, val, len, outd);
idx, val, len, outd);
} }
template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
#if BF16_ENABLED #if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, __nv_bfloat16, int32_t>(
IdArray, NDArray, NDArray);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
...@@ -49,7 +50,8 @@ template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray); ...@@ -49,7 +50,8 @@ template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
#if BF16_ENABLED #if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, __nv_bfloat16, int64_t>(
IdArray, NDArray, NDArray);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray); template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment