Unverified Commit 8ac27dad authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4824)



* [Misc] clang-format auto fix.

* blabla

* ablabla

* blabla
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bcd37684
@@ -10,10 +10,11 @@
#define DGL_ATEN_ARRAY_OPS_H_
#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <tuple>
#include <string>
#include "./types.h"
namespace dgl {
@@ -24,17 +25,16 @@ namespace aten {
//////////////////////////////////////////////////////////////////////
/** @return A special array to represent null. */
inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
inline NDArray NullArray(
const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
return NDArray::Empty({0}, dtype, ctx);
}
/**
* @return Whether the input array is a null array.
*/
inline bool IsNullArray(NDArray array) {
return array->shape[0] == 0;
}
inline bool IsNullArray(NDArray array) { return array->shape[0] == 0; }
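For reference, a minimal usage sketch of the pair above (illustrative only, not part of this patch; uses only the declarations shown here):

// A null array is a zero-length placeholder, so IsNullArray just tests shape[0] == 0.
NDArray a = NullArray();                    // 0-length int64 array on CPU by default
bool is_null = IsNullArray(a);              // true
bool nonnull = IsNullArray(NewIdArray(5));  // false: shape[0] == 5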
/**
* @brief Create a new id array with given length
@@ -43,9 +43,9 @@ inline bool IsNullArray(NDArray array) {
* @param nbits The number of integer bits
* @return id array
*/
IdArray NewIdArray(int64_t length,
DGLContext ctx = DGLContext{kDGLCPU, 0},
uint8_t nbits = 64);
IdArray NewIdArray(
int64_t length, DGLContext ctx = DGLContext{kDGLCPU, 0},
uint8_t nbits = 64);
/**
* @brief Create a new id array using the given vector data
@@ -55,9 +55,9 @@ IdArray NewIdArray(int64_t length,
* @return the id array
*/
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits = 64,
DGLContext ctx = DGLContext{kDGLCPU, 0});
IdArray VecToIdArray(
const std::vector<T>& vec, uint8_t nbits = 64,
DGLContext ctx = DGLContext{kDGLCPU, 0});
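An illustrative call of VecToIdArray with its defaults (hypothetical values; only the declaration above is assumed):

// Build a 64-bit IdArray on CPU from a std::vector.
std::vector<int64_t> src = {0, 1, 1, 2};
IdArray ids = VecToIdArray(src);  // nbits = 64, ctx = CPU by default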
/**
* @brief Return an array representing a 1D range.
@@ -148,7 +148,7 @@ IdArray NonZero(BoolArray bool_arr);
* @brief Return the data under the index. In numpy notation, A[I]
* @tparam ValueType The type of return value.
*/
template<typename ValueType>
template <typename ValueType>
ValueType IndexSelect(NDArray array, int64_t index);
/**
@@ -187,10 +187,11 @@ NDArray Scatter(NDArray array, IdArray indices);
void Scatter_(IdArray index, NDArray value, NDArray out);
/**
* @brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats)
* @brief Repeat each element a number of times. Equivalent to np.repeat(array,
* repeats)
* @param array A 1D vector
* @param repeats A 1D integer vector giving the number of times to repeat each element in
* \c array. Must have the same shape as \c array.
* @param repeats A 1D integer vector giving the number of times to repeat each
* element in \c array. Must have the same shape as \c array.
*/
NDArray Repeat(NDArray array, IdArray repeats);
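A hypothetical sketch of the np.repeat semantics documented above (values are illustrative):

// Repeat([10, 20, 30], [2, 0, 3]) -> [10, 10, 30, 30, 30]
NDArray array = NDArray::FromVector(std::vector<int64_t>{10, 20, 30});
IdArray repeats = VecToIdArray(std::vector<int64_t>{2, 0, 3});
NDArray out = Repeat(array, repeats);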
@@ -253,12 +254,13 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
*
* The packed tensor would be [1, 2, 3, 4, 5].
*
* The length tensor would be [2, 3], i.e. the length of each sequence before padding.
* The length tensor would be [2, 3], i.e. the length of each sequence before
* padding.
*
* The offset tensor would be [0, 2], i.e. the offset to the packed tensor for each
* sequence (before padding)
* The offset tensor would be [0, 2], i.e. the offset to the packed tensor for
* each sequence (before padding)
*/
template<typename ValueType>
template <typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
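A hedged usage sketch of Pack matching the worked example in the comment above (`padded` is a hypothetical 2D float array padded with 0):

// Pack strips pad_value and returns (packed values, per-row lengths, offsets).
NDArray packed;
IdArray lengths, offsets;
std::tie(packed, lengths, offsets) = Pack<float>(padded, /*pad_value=*/0.f);
// For the example above: packed = [1, 2, 3, 4, 5], lengths = [2, 3], offsets = [0, 2].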
/**
@@ -295,8 +297,9 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
* @brief Return the cumulative summation (or inclusive sum) of the input array.
*
* The first element out[0] is equal to the first element of the input array
* array[0]. The remaining elements are defined recursively, out[i] = out[i-1] + array[i].
* Hence, the result array length is the same as the input array length.
* array[0]. The remaining elements are defined recursively, out[i] = out[i-1] +
* array[i]. Hence, the result array length is the same as the input array
* length.
*
* If prepend_zero is true, then the first element is zero and the result array
* length is the input array length plus one. This is useful for creating
@@ -320,17 +323,18 @@ IdArray NonZero(NDArray array);
/**
* @brief Sort the ID vector in ascending order.
*
* It performs both sort and arg_sort (returning the sorted index). The sorted index
* is always in int64.
* It performs both sort and arg_sort (returning the sorted index). The sorted
* index is always in int64.
*
* @param array Input array.
* @param num_bits The number of bits used in key comparison. For example, if the data type
* of the input array is int32_t and `num_bits = 8`, it only uses bits in index
* range [0, 8) for sorting. Setting it to a small value could
* speed up the sorting if the underlying sorting algorithm is radix sort (e.g., on GPU).
* Setting it to zero (default value) means using all the bits for comparison.
* On CPU, it currently has no effect.
* @return A pair of arrays: sorted values and sorted index to the original position.
* @param num_bits The number of bits used in key comparison. For example, if
* the data type of the input array is int32_t and `num_bits = 8`, it only uses
* bits in index range [0, 8) for sorting. Setting it to a small value could
* speed up the sorting if the underlying sorting algorithm is
* radix sort (e.g., on GPU). Setting it to zero (default value) means using all
* the bits for comparison. On CPU, it currently has no effect.
* @return A pair of arrays: sorted values and sorted index to the original
* position.
*/
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0);
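An illustrative use of Sort (only the declaration above is assumed):

// Sort returns the sorted values plus the permutation (always int64).
IdArray ids = VecToIdArray(std::vector<int64_t>{3, 1, 2});
IdArray sorted, order;
std::tie(sorted, order) = Sort(ids);  // sorted = [1, 2, 3], order = [1, 2, 0]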
@@ -341,9 +345,7 @@ std::string ToDebugString(NDArray array);
// inline implementations
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits,
DGLContext ctx) {
IdArray VecToIdArray(const std::vector<T>& vec, uint8_t nbits, DGLContext ctx) {
IdArray ret = NewIdArray(vec.size(), DGLContext{kDGLCPU, 0}, nbits);
if (nbits == 32) {
std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
@@ -367,7 +369,8 @@ inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) {
first = false;
result = array->ctx;
} else {
CHECK_EQ(array->ctx, result) << "Context of the input arrays are different";
CHECK_EQ(array->ctx, result)
<< "Context of the input arrays are different";
}
}
return result;
......
@@ -9,14 +9,16 @@
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include <vector>
#include <utility>
#include <tuple>
#include <string>
#include "./types.h"
#include <tuple>
#include <utility>
#include <vector>
#include "./array_ops.h"
#include "./spmat.h"
#include "./macro.h"
#include "./spmat.h"
#include "./types.h"
namespace dgl {
namespace aten {
@@ -52,9 +54,9 @@ struct COOMatrix {
/** @brief default constructor */
COOMatrix() = default;
/** @brief constructor */
COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false,
bool csorted = false)
COOMatrix(
int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false, bool csorted = false)
: num_rows(nrows),
num_cols(ncols),
row(rarr),
@@ -79,8 +81,9 @@ struct COOMatrix {
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCOO), num_rows,
num_cols, {row, col, data}, {row_sorted, col_sorted});
return SparseMatrix(
static_cast<int32_t>(SparseFormat::kCOO), num_rows, num_cols,
{row, col, data}, {row_sorted, col_sorted});
}
bool Load(dmlc::Stream* fs) {
@@ -122,25 +125,24 @@
}
/** @brief Return a copy of this matrix on the given device context. */
inline COOMatrix CopyTo(const DGLContext &ctx) const {
if (ctx == row->ctx)
return *this;
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx),
row_sorted, col_sorted);
inline COOMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == row->ctx) return *this;
return COOMatrix(
num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx), row_sorted,
col_sorted);
}
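An illustrative sketch of the reformatted CopyTo (assumes a CUDA build and an existing `coo`; not part of this patch):

// Copy all constituent arrays to GPU 0; a null data array is left as-is.
COOMatrix gpu_coo = coo.CopyTo(DGLContext{kDGLCUDA, 0});
// Copying to the matrix's own context is a no-op that returns *this.
COOMatrix same = gpu_coo.CopyTo(gpu_coo.row->ctx);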
/**
* @brief Pin the row, col and data (if not Null) of the matrix.
* @note This is an in-place method. Behavior depends on the current context,
* kDGLCPU: will be pinned;
* IsPinned: directly return;
* kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
* @brief Pin the row, col and data (if not Null) of the matrix.
* @note This is an in-place method. Behavior depends on the current context,
* kDGLCPU: will be pinned;
* IsPinned: directly return;
* kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
if (is_pinned)
return;
if (is_pinned) return;
row.PinMemory_();
col.PinMemory_();
if (!aten::IsNullArray(data)) {
@@ -150,15 +152,14 @@ struct COOMatrix {
}
/**
* @brief Unpin the row, col and data (if not Null) of the matrix.
* @note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
* The context check is deferred to unpinning the NDArray.
*/
* @brief Unpin the row, col and data (if not Null) of the matrix.
* @note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
* The context check is deferred to unpinning the NDArray.
*/
inline void UnpinMemory_() {
if (!is_pinned)
return;
if (!is_pinned) return;
row.UnpinMemory_();
col.UnpinMemory_();
if (!aten::IsNullArray(data)) {
@@ -183,25 +184,25 @@ struct COOMatrix {
///////////////////////// COO routines //////////////////////////
/** @brief Return true if the value (row, col) is non-zero */
bool COOIsNonZero(COOMatrix , int64_t row, int64_t col);
bool COOIsNonZero(COOMatrix, int64_t row, int64_t col);
/**
* @brief Batched implementation of COOIsNonZero.
* @note This operator allows broadcasting (i.e., either row or col can be of length 1).
* @note This operator allows broadcasting (i.e., either row or col can be of
* length 1).
*/
runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col);
runtime::NDArray COOIsNonZero(
COOMatrix, runtime::NDArray row, runtime::NDArray col);
/** @brief Return the nnz of the given row */
int64_t COOGetRowNNZ(COOMatrix , int64_t row);
runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row);
int64_t COOGetRowNNZ(COOMatrix, int64_t row);
runtime::NDArray COOGetRowNNZ(COOMatrix, runtime::NDArray row);
/** @brief Return the data array of the given row */
std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix , int64_t row);
std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
COOMatrix, int64_t row);
/** @brief Whether the COO matrix contains data */
inline bool COOHasData(COOMatrix csr) {
return !IsNullArray(csr.data);
}
inline bool COOHasData(COOMatrix csr) { return !IsNullArray(csr.data); }
/**
* @brief Check whether the COO is sorted.
@@ -217,11 +218,12 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo);
/**
* @brief Get the data and the row,col indices for each returned entries.
*
* The operator supports matrices with duplicate entries and all the matched entries
* will be returned. The operator assumes there is NO duplicate (row, col) pair
* in the given input. Otherwise, the returned result is undefined.
* The operator supports matrices with duplicate entries and all the matched
* entries will be returned. The operator assumes there is NO duplicate (row,
* col) pair in the given input. Otherwise, the returned result is undefined.
*
* @note This operator allows broadcasting (i.e., either row or col can be of length 1).
* @note This operator allows broadcasting (i.e., either row or col can be of
* length 1).
* @param mat Sparse matrix
* @param rows Row index
* @param cols Column index
@@ -230,10 +232,13 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo);
std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/** @brief Get data. The return type is an ndarray due to possible duplicate entries. */
/** @brief Get data. The return type is an ndarray due to possible duplicate
* entries. */
inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
IdArray rows = VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);
IdArray cols = VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);
IdArray rows =
VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);
IdArray cols =
VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);
const auto& rst = COOGetDataAndIndices(mat, rows, cols);
return rst[2];
}
@@ -241,18 +246,20 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
/**
* @brief Get the data for each (row, col) pair.
*
* The operator supports matrices with duplicate entries, but only one matched entry
* will be returned for each (row, col) pair. Duplicate input (row, col)
* pairs are supported.
* The operator supports matrices with duplicate entries, but only one matched
* entry will be returned for each (row, col) pair. Duplicate input
* (row, col) pairs are supported.
*
* @note This operator allows broadcasting (i.e., either row or col can be of length 1).
* @note This operator allows broadcasting (i.e., either row or col can be of
* length 1).
*
* @param mat Sparse matrix.
* @param rows Row index.
* @param cols Column index.
* @return Data array. The i^th element is the data of (rows[i], cols[i])
*/
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
runtime::NDArray COOGetData(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/** @brief Return a transposed COO matrix */
COOMatrix COOTranspose(COOMatrix coo);
@@ -301,14 +308,15 @@ COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
* @param cols The col index to select
* @return submatrix
*/
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
COOMatrix COOSliceMatrix(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
/** @return True if the matrix has duplicate entries */
bool COOHasDuplicate(COOMatrix coo);
/**
* @brief Deduplicate the entries of a sorted COO matrix, replacing the data with the
* number of occurrences of the row-col coordinates.
* @brief Deduplicate the entries of a sorted COO matrix, replacing the data
* with the number of occurrences of the row-col coordinates.
*/
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
@@ -316,11 +324,13 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
* @brief Sort the indices of a COO matrix in-place.
*
* The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix
* stores the shuffled index which could be used to fetch edge data.
* col indices are sorted in ascending order too. The data array of the returned
* COOMatrix stores the shuffled index which could be used to fetch edge data.
*
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space.
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of
* nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N)
* space.
*
* @param mat The coo matrix to sort.
* @param sort_column True if column index should be sorted too.
@@ -331,31 +341,32 @@ void COOSort_(COOMatrix* mat, bool sort_column = false);
* @brief Sort the indices of a COO matrix.
*
* The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix
* stores the shuffled index which could be used to fetch edge data.
* col indices are sorted in ascending order too. The data array of the returned
* COOMatrix stores the shuffled index which could be used to fetch edge data.
*
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space.
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of
* nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N)
* space.
*
* @param mat The input coo matrix
* @param sort_column True if column index should be sorted too.
* @return COO matrix with index sorted.
*/
inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) {
if ((mat.row_sorted && !sort_column) || mat.col_sorted)
return mat;
COOMatrix ret(mat.num_rows, mat.num_cols,
mat.row.Clone(), mat.col.Clone(),
COOHasData(mat)? mat.data.Clone() : mat.data,
mat.row_sorted, mat.col_sorted);
if ((mat.row_sorted && !sort_column) || mat.col_sorted) return mat;
COOMatrix ret(
mat.num_rows, mat.num_cols, mat.row.Clone(), mat.col.Clone(),
COOHasData(mat) ? mat.data.Clone() : mat.data, mat.row_sorted,
mat.col_sorted);
COOSort_(&ret, sort_column);
return ret;
}
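An illustrative call (assumes an existing `coo`; not part of this patch). After sorting, the data field is the shuffle index described above:

COOMatrix sorted = COOSort(coo, /*sort_column=*/true);
IdArray shuffle = sorted.data;  // maps sorted entries back to original edge ids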
/**
* @brief Remove entries from COO matrix by entry indices (data indices)
* @return A new COO matrix as well as a mapping from the new COO entries to the old COO
* entries.
* @return A new COO matrix as well as a mapping from the new COO entries to the
* old COO entries.
*/
COOMatrix COORemove(COOMatrix coo, IdArray entries);
@@ -365,10 +376,12 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries);
* @param new_row_ids the new row Ids (the index is the old row Id)
* @param new_col_ids the new column Ids (the index is the old col Id).
*/
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
COOMatrix COOReorder(
COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/**
* @brief Randomly select a fixed number of non-zero entries along each given row independently.
* @brief Randomly select a fixed number of non-zero entries along each given
* row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
@@ -400,15 +413,12 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArr
* Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* @param replace True if sample with replacement
* @return A COOMatrix storing the picked row and col indices. Its data field stores
* the index of the picked elements in the value array.
* @return A COOMatrix storing the picked row and col indices. Its data field
* stores the index of the picked elements in the value array.
*/
COOMatrix COORowWiseSampling(
COOMatrix mat,
IdArray rows,
int64_t num_samples,
NDArray prob_or_mask = NDArray(),
bool replace = true);
COOMatrix mat, IdArray rows, int64_t num_samples,
NDArray prob_or_mask = NDArray(), bool replace = true);
/**
* @brief Randomly select a fixed number of non-zero entries for each edge type
@@ -433,8 +443,8 @@ COOMatrix COORowWiseSampling(
* COOMatrix coo = ...;
* IdArray rows = ... ; // [0, 3]
* std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset, num_samples,
* FloatArray(), false);
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset,
* num_samples, FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
@@ -450,20 +460,18 @@ COOMatrix COORowWiseSampling(
* Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* @param replace True if sample with replacement
* @return A COOMatrix storing the picked row and col indices. Its data field stores
* the index of the picked elements in the value array.
* @return A COOMatrix storing the picked row and col indices. Its data field
* stores the index of the picked elements in the value array.
* @note The edges of the entire graph must be ordered by their edge types.
*/
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat,
IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask,
bool replace = true);
const std::vector<NDArray>& prob_or_mask, bool replace = true);
/**
* @brief Select K non-zero entries with the largest weights along each given row.
* @brief Select K non-zero entries with the largest weights along each given
* row.
*
* The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix.
@@ -492,18 +500,15 @@ COOMatrix COORowWisePerEtypeSampling(
* @param mat Input COO matrix.
* @param rows Rows to sample from.
* @param k The K value.
* @param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform.
* @param ascending If true, elements are sorted in ascending order, equivalent to finding
* the K smallest values. Otherwise, find the K largest values.
* @return A COOMatrix storing the picked row and col indices. Its data field stores
* the index of the picked elements in the value array.
* @param weight Weight associated with each entry. Should be of the same length
* as the data array. If an empty array is provided, assume uniform.
* @param ascending If true, elements are sorted in ascending order, equivalent
* to finding the K smallest values. Otherwise, find the K largest values.
* @return A COOMatrix storing the picked row and col indices. Its data field
* stores the index of the picked elements in the value array.
*/
COOMatrix COORowWiseTopk(
COOMatrix mat,
IdArray rows,
int64_t k,
NDArray weight,
COOMatrix mat, IdArray rows, int64_t k, NDArray weight,
bool ascending = false);
/**
@@ -535,8 +540,7 @@ COOMatrix COORowWiseTopk(
* COOMatrix_C.num_rows : 3
* COOMatrix_C.num_cols : 4
*/
COOMatrix UnionCoo(
const std::vector<COOMatrix>& coos);
COOMatrix UnionCoo(const std::vector<COOMatrix>& coos);
/**
* @brief DisjointUnion a list of COOMatrix into one COOMatrix.
@@ -566,12 +570,13 @@ COOMatrix UnionCoo(
* COOMatrix_C.num_cols : 5
*
* @param coos The input list of coo matrix.
* @param src_offset A list of integers recording the src vertex id offset of each Matrix in coos
* @param dst_offset A list of integers recording the dst vertex id offset of each Matrix in coos
* @param src_offset A list of integers recording the src vertex id offset of each
* Matrix in coos
* @param dst_offset A list of integers recording the dst vertex id offset of each
* Matrix in coos
* @return The combined COOMatrix.
*/
COOMatrix DisjointUnionCoo(
const std::vector<COOMatrix>& coos);
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);
/**
* @brief COOMatrix toSimple.
@@ -591,8 +596,8 @@ COOMatrix DisjointUnionCoo(
* edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4]
*
* @return The simplified COOMatrix
* The count recording the number of duplicated edges from the original graph.
* The edge mapping from the edge IDs of the original graph to those of the
* The count recording the number of duplicated edges from the original
* graph. The edge mapping from the edge IDs of the original graph to those of the
* returned graph.
*/
std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
@@ -642,11 +647,10 @@ std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
* @return A list of COOMatrixes representing each disjoint components.
*/
std::vector<COOMatrix> DisjointPartitionCooBySizes(
const COOMatrix &coo,
const uint64_t batch_size,
const std::vector<uint64_t> &edge_cumsum,
const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &dst_vertex_cumsum);
const COOMatrix& coo, const uint64_t batch_size,
const std::vector<uint64_t>& edge_cumsum,
const std::vector<uint64_t>& src_vertex_cumsum,
const std::vector<uint64_t>& dst_vertex_cumsum);
/**
* @brief Slice a contiguous chunk from a COOMatrix
@@ -684,10 +688,9 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
* @return COOMatrix representing the chunk.
*/
COOMatrix COOSliceContiguousChunk(
const COOMatrix &coo,
const std::vector<uint64_t> &edge_range,
const std::vector<uint64_t> &src_vertex_range,
const std::vector<uint64_t> &dst_vertex_range);
const COOMatrix& coo, const std::vector<uint64_t>& edge_range,
const std::vector<uint64_t>& src_vertex_range,
const std::vector<uint64_t>& dst_vertex_range);
/**
* @brief Create a LineGraph of input coo
......@@ -716,10 +719,11 @@ COOMatrix COOSliceContiguousChunk(
* [0, 1, 1, 0, 0]]
*
* @param coo COOMatrix to create the LineGraph
* @param backtracking whether the pair of (v, u) and (u, v) edges is treated as linked
* @param backtracking whether the pair of (v, u) and (u, v) edges is treated
* as linked
* @return LineGraph in COO format
*/
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
} // namespace aten
} // namespace dgl
......
@@ -223,8 +223,10 @@ bool CSRIsSorted(CSRMatrix csr);
std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
/* @brief Get data. The return type is an ndarray due to possible duplicate
* entries. */
/**
* @brief Get data. The return type is an ndarray due to possible duplicate
* entries.
*/
inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
const auto& nbits = mat.indptr->dtype.bits;
const auto& ctx = mat.indptr->ctx;
......
@@ -17,16 +17,16 @@
* DeviceSpecificImplementation<XPU>(...);
* });
*/
#define ATEN_XPU_SWITCH(val, XPU, op, ...) do { \
if ((val) == kDGLCPU) { \
constexpr auto XPU = kDGLCPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) \
<< " device."; \
} \
} while (0)
#define ATEN_XPU_SWITCH(val, XPU, op, ...) \
do { \
if ((val) == kDGLCPU) { \
constexpr auto XPU = kDGLCPU; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) << " device."; \
} \
} while (0)
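The `do { ... } while (0)` wrapper that clang-format reindents here is the standard idiom that makes a multi-statement macro behave as a single statement; a minimal illustration with hypothetical macros and functions f/g/h (not part of this header):

#define BAD(x) f(x); g(x)                     // two statements
#define GOOD(x) do { f(x); g(x); } while (0)  // one statement

if (cond)
  GOOD(v);   // expands safely; the caller's ';' terminates the do-while
else
  h(v);      // with BAD(v) above, this 'else' would fail to compile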
/**
* Dispatch according to device:
@@ -37,24 +37,24 @@
* // Now XPU is a placeholder for array->ctx.device_type
* DeviceSpecificImplementation<XPU>(...);
* });
*
*
* We treat pinned memory as normal host memory if we don't want
* to enable CUDA UVA access for this operator
*/
#ifdef DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \
if ((val) == kDGLCPU) { \
constexpr auto XPU = kDGLCPU; \
{__VA_ARGS__} \
} else if ((val) == kDGLCUDA) { \
constexpr auto XPU = kDGLCUDA; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) \
<< " device."; \
} \
} while (0)
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) \
do { \
if ((val) == kDGLCPU) { \
constexpr auto XPU = kDGLCPU; \
{ __VA_ARGS__ } \
} else if ((val) == kDGLCUDA) { \
constexpr auto XPU = kDGLCUDA; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
<< dgl::runtime::DeviceTypeCode2Str(val) << " device."; \
} \
} while (0)
#else // DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA ATEN_XPU_SWITCH
#endif // DGL_USE_CUDA
@@ -68,18 +68,19 @@
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \
CHECK_EQ((val).code, kDGLInt) << "ID must be integer type"; \
if ((val).bits == 32) { \
typedef int32_t IdType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef int64_t IdType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
} while (0)
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) \
do { \
CHECK_EQ((val).code, kDGLInt) << "ID must be integer type"; \
if ((val).bits == 32) { \
typedef int32_t IdType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 64) { \
typedef int64_t IdType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
} while (0)
/**
* Dispatch according to bits (either int32 or int64):
@@ -90,18 +91,18 @@
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_ID_BITS_SWITCH(bits, IdType, ...) \
do { \
#define ATEN_ID_BITS_SWITCH(bits, IdType, ...) \
do { \
CHECK((bits) == 32 || (bits) == 64) << "bits must be 32 or 64"; \
if ((bits) == 32) { \
typedef int32_t IdType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef int64_t IdType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
if ((bits) == 32) { \
typedef int32_t IdType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef int64_t IdType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "ID can only be int32 or int64"; \
} \
} while (0)
/**
@@ -113,75 +114,79 @@
* FloatType *data = static_cast<FloatType *>(array->data);
* });
*/
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \
CHECK_EQ((val).code, kDGLFloat) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float32 or float64"; \
} \
} while (0)
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) \
do { \
CHECK_EQ((val).code, kDGLFloat) << (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << (val_name) << " can only be float32 or float64"; \
} \
} while (0)
/**
* Dispatch according to float type, including 16bits (float16/bfloat16/float32/float64).
* Dispatch according to float type, including 16bits
* (float16/bfloat16/float32/float64).
*/
#ifdef DGL_USE_CUDA
#if BF16_ENABLED
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
typedef __nv_bfloat16 FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCPU) { \
LOG(FATAL) << (val_name) \
<< " can only be float32 or float64 on CPU"; \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/bfloat16/float32/float64 on GPU"; \
} \
} while (0)
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{ __VA_ARGS__ } \
} else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{ __VA_ARGS__ } \
} else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
typedef __nv_bfloat16 FloatType; \
{ __VA_ARGS__ } \
} else if (XPU == kDGLCPU) { \
LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/bfloat16/float32/float64 on GPU"; \
} \
} while (0)
#else // BF16_ENABLED
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{__VA_ARGS__} \
} else if (XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if (XPU == kDGLCPU) { \
LOG(FATAL) << (val_name) \
<< " can only be float32 or float64 on CPU"; \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/float32/float64 on GPU"; \
} \
} while (0)
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{ __VA_ARGS__ } \
} else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{ __VA_ARGS__ } \
} else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if (XPU == kDGLCPU) { \
LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/float32/float64 on GPU"; \
} \
} while (0)
#endif // BF16_ENABLED
#else // DGL_USE_CUDA
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
#else // DGL_USE_CUDA
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__})
#endif // DGL_USE_CUDA
@@ -194,23 +199,25 @@
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do { \
if ((val).code == kDGLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDGLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDGLFloat && (val).bits == 32) { \
typedef float DType; \
{__VA_ARGS__} \
} else if ((val).code == kDGLFloat && (val).bits == 64) { \
typedef double DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be int32, int64, float32 or float64"; \
} \
} while (0)
#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) \
do { \
if ((val).code == kDGLInt && (val).bits == 32) { \
typedef int32_t DType; \
{ __VA_ARGS__ } \
} else if ((val).code == kDGLInt && (val).bits == 64) { \
typedef int64_t DType; \
{ __VA_ARGS__ } \
} else if ((val).code == kDGLFloat && (val).bits == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((val).code == kDGLFloat && (val).bits == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be int32, int64, float32 or float64"; \
} \
} while (0)
/**
* Dispatch according to data type (int8, uint8, float32 or float64):
@@ -221,23 +228,25 @@
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(val, DType, val_name, ...) do { \
if ((val).code == kDGLInt && (val).bits == 8) { \
typedef int8_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDGLUInt && (val).bits == 8) { \
typedef uint8_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDGLFloat && (val).bits == 32) { \
typedef float DType; \
{__VA_ARGS__} \
} else if ((val).code == kDGLFloat && (val).bits == 64) { \
typedef double DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be int8, uint8, float32 or float64"; \
} \
} while (0)
#define ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(val, DType, val_name, ...) \
do { \
if ((val).code == kDGLInt && (val).bits == 8) { \
typedef int8_t DType; \
{ __VA_ARGS__ } \
} else if ((val).code == kDGLUInt && (val).bits == 8) { \
typedef uint8_t DType; \
{ __VA_ARGS__ } \
} else if ((val).code == kDGLFloat && (val).bits == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((val).code == kDGLFloat && (val).bits == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be int8, uint8, float32 or float64"; \
} \
} while (0)
/**
* Dispatch data type only based on bit-width (8-bit, 16-bit, 32-bit, 64-bit):
@@ -250,61 +259,62 @@
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_DTYPE_BITS_ONLY_SWITCH(val, DType, val_name, ...) do { \
if ((val).bits == 8) { \
typedef int8_t DType; \
{__VA_ARGS__} \
} else if ((val).bits == 16) { \
typedef int16_t DType; \
{__VA_ARGS__} \
} else if ((val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be 8-bit, 16-bit, 32-bit, or 64-bit"; \
} \
} while (0)
#define ATEN_DTYPE_BITS_ONLY_SWITCH(val, DType, val_name, ...) \
do { \
if ((val).bits == 8) { \
typedef int8_t DType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 16) { \
typedef int16_t DType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 32) { \
typedef int32_t DType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 64) { \
typedef int64_t DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be 8-bit, 16-bit, 32-bit, or 64-bit"; \
} \
} while (0)
/**
* Dispatch according to integral type of CSR graphs.
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
*/
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \
if ((val).code == kDGLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDGLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "CSR matrix data can only be int32 or int64"; \
} \
} while (0)
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) \
do { \
if ((val).code == kDGLInt && (val).bits == 32) { \
typedef int32_t DType; \
{ __VA_ARGS__ } \
} else if ((val).code == kDGLInt && (val).bits == 64) { \
typedef int64_t DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "CSR matrix data can only be int32 or int64"; \
} \
} while (0)
// Macro to dispatch according to device context and index type.
#define ATEN_CSR_SWITCH(csr, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
#define ATEN_CSR_SWITCH(csr, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {{__VA_ARGS__}}); \
});
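An illustrative use of the combined dispatch (hypothetical operator name; the pattern follows the usage comments earlier in this header):

ATEN_CSR_SWITCH(csr, XPU, IdType, "MyOp", {
  // Inside the body, XPU is the device constant and IdType is int32_t/int64_t.
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  // ... device- and dtype-specific work ...
});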
// Macro to dispatch according to device context and index type.
#define ATEN_COO_SWITCH(coo, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \
{__VA_ARGS__} \
}); \
#define ATEN_COO_SWITCH(coo, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {{__VA_ARGS__}}); \
});
#define CHECK_VALID_CONTEXT(VAR1, VAR2) \
CHECK(((VAR1)->ctx == (VAR2)->ctx) || (VAR1).IsPinned()) \
<< "Expected " << (#VAR2) << "(" << (VAR2)->ctx << ")" << " to have the same device " \
<< "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). " \
<< "Or " << (#VAR1) << "(" << (VAR1)->ctx << ")" << " is pinned";
#define CHECK_VALID_CONTEXT(VAR1, VAR2) \
CHECK(((VAR1)->ctx == (VAR2)->ctx) || (VAR1).IsPinned()) \
<< "Expected " << (#VAR2) << "(" << (VAR2)->ctx << ")" \
<< " to have the same device " \
<< "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). " \
<< "Or " << (#VAR1) << "(" << (VAR1)->ctx << ")" \
<< " is pinned";
/**
* Macro to dispatch according to the context of array and dtype of csr
@@ -313,30 +323,25 @@
* If csr has the same context as array, the behavior is the same as ATEN_CSR_SWITCH_CUDA.
* If csr is pinned, the actual operation runs on array's context.
*/
#define ATEN_CSR_SWITCH_CUDA_UVA(csr, array, XPU, IdType, op, ...) do { \
CHECK_VALID_CONTEXT(csr.indices, array); \
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
}); \
} while (0)
#define ATEN_CSR_SWITCH_CUDA_UVA(csr, array, XPU, IdType, op, ...) \
do { \
CHECK_VALID_CONTEXT(csr.indices, array); \
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {{__VA_ARGS__}}); \
}); \
} while (0)
// Macro to dispatch according to device context (allowing cuda)
#ifdef DGL_USE_CUDA
#define ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH_CUDA((csr).indptr->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
#define ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH_CUDA((csr).indptr->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, {{__VA_ARGS__}}); \
});
// Macro to dispatch according to device context and index type.
#define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH_CUDA((coo).row->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \
{__VA_ARGS__} \
}); \
#define ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, op, ...) \
ATEN_XPU_SWITCH_CUDA((coo).row->ctx.device_type, XPU, op, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, {{__VA_ARGS__}}); \
});
#else // DGL_USE_CUDA
#define ATEN_CSR_SWITCH_CUDA ATEN_CSR_SWITCH
......@@ -345,54 +350,56 @@
///////////////////////// Array checks //////////////////////////
#define IS_INT32(a) \
((a)->dtype.code == kDGLInt && (a)->dtype.bits == 32)
#define IS_INT64(a) \
((a)->dtype.code == kDGLInt && (a)->dtype.bits == 64)
#define IS_FLOAT32(a) \
((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 32)
#define IS_FLOAT64(a) \
((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 64)
#define IS_INT32(a) ((a)->dtype.code == kDGLInt && (a)->dtype.bits == 32)
#define IS_INT64(a) ((a)->dtype.code == kDGLInt && (a)->dtype.bits == 64)
#define IS_FLOAT32(a) ((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 32)
#define IS_FLOAT64(a) ((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 64)
#define CHECK_IF(cond, prop, value_name, dtype_name) \
CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " << (dtype_name)
#define CHECK_IF(cond, prop, value_name, dtype_name) \
CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " \
<< (dtype_name)
#define CHECK_INT32(value, value_name) \
CHECK_IF(IS_INT32(value), "dtype", value_name, "int32")
#define CHECK_INT64(value, value_name) \
CHECK_IF(IS_INT64(value), "dtype", value_name, "int64")
#define CHECK_INT(value, value_name) \
CHECK_IF(IS_INT32(value) || IS_INT64(value), "dtype", value_name, "int32 or int64")
#define CHECK_INT(value, value_name) \
CHECK_IF( \
IS_INT32(value) || IS_INT64(value), "dtype", value_name, \
"int32 or int64")
#define CHECK_FLOAT32(value, value_name) \
CHECK_IF(IS_FLOAT32(value), "dtype", value_name, "float32")
#define CHECK_FLOAT64(value, value_name) \
CHECK_IF(IS_FLOAT64(value), "dtype", value_name, "float64")
#define CHECK_FLOAT(value, value_name) \
CHECK_IF(IS_FLOAT32(value) || IS_FLOAT64(value), "dtype", value_name, "float32 or float64")
#define CHECK_FLOAT(value, value_name) \
CHECK_IF( \
IS_FLOAT32(value) || IS_FLOAT64(value), "dtype", value_name, \
"float32 or float64")
#define CHECK_NDIM(value, _ndim, value_name) \
CHECK_IF((value)->ndim == (_ndim), "ndim", value_name, _ndim)
#define CHECK_SAME_DTYPE(VAR1, VAR2) \
CHECK((VAR1)->dtype == (VAR2)->dtype) \
<< "Expected " << (#VAR2) << " to be the same type as " << (#VAR1) << "(" \
<< (VAR1)->dtype << ")" \
<< ". But got " << (VAR2)->dtype << ".";
#define CHECK_SAME_DTYPE(VAR1, VAR2) \
CHECK((VAR1)->dtype == (VAR2)->dtype) \
<< "Expected " << (#VAR2) << " to be the same type as " << (#VAR1) \
<< "(" << (VAR1)->dtype << ")" \
<< ". But got " << (VAR2)->dtype << ".";
#define CHECK_SAME_CONTEXT(VAR1, VAR2) \
CHECK((VAR1)->ctx == (VAR2)->ctx) \
<< "Expected " << (#VAR2) << " to have the same device context as " << (#VAR1) << "(" \
<< (VAR1)->ctx << ")" \
<< ". But got " << (VAR2)->ctx << ".";
#define CHECK_SAME_CONTEXT(VAR1, VAR2) \
CHECK((VAR1)->ctx == (VAR2)->ctx) \
<< "Expected " << (#VAR2) << " to have the same device context as " \
<< (#VAR1) << "(" << (VAR1)->ctx << ")" \
<< ". But got " << (VAR2)->ctx << ".";
#define CHECK_NO_OVERFLOW(dtype, val) \
do { \
if (sizeof(val) == 8 && (dtype).bits == 32) \
CHECK_LE((val), 0x7FFFFFFFL) << "int32 overflow for argument " << (#val) << "."; \
#define CHECK_NO_OVERFLOW(dtype, val) \
do { \
if (sizeof(val) == 8 && (dtype).bits == 32) \
CHECK_LE((val), 0x7FFFFFFFL) \
<< "int32 overflow for argument " << (#val) << "."; \
} while (0);
#define CHECK_IS_ID_ARRAY(VAR) \
CHECK((VAR)->ndim == 1 && (IS_INT32(VAR) || IS_INT64(VAR))) \
<< "Expected argument " << (#VAR) << " to be an 1D integer array.";
#define CHECK_IS_ID_ARRAY(VAR) \
CHECK((VAR)->ndim == 1 && (IS_INT32(VAR) || IS_INT64(VAR))) \
<< "Expected argument " << (#VAR) << " to be an 1D integer array.";
#endif // DGL_ATEN_MACRO_H_
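Illustrative uses of the check macros above (hypothetical arrays `ids` and `weights`):

CHECK_IS_ID_ARRAY(ids);            // 1D array of int32 or int64
CHECK_INT(ids, "ids");             // dtype must be int32 or int64
CHECK_SAME_CONTEXT(ids, weights);  // both arrays on the same device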
@@ -10,7 +10,8 @@
#include <memory>
namespace dgl {
// Util class to call the private/public empty constructor, which is needed for serialization
// Util class to call the private/public empty constructor, which is needed for
// serialization
class Serializer {
public:
template <typename T>
......
@@ -6,12 +6,11 @@
#ifndef DGL_RUNTIME_NDARRAY_H_
#define DGL_RUNTIME_NDARRAY_H_
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include "c_runtime_api.h"
#include "serializer.h"
@@ -29,7 +28,7 @@
#endif // DGL_USE_CUDA
// forward declaration
inline std::ostream& operator << (std::ostream& os, DGLDataType t);
inline std::ostream& operator<<(std::ostream& os, DGLDataType t);
namespace dgl {
@@ -39,13 +38,13 @@ namespace dgl {
* Usage:
* DGLDataTypeTraits<int>::dtype == dtype
*/
template<typename T>
template <typename T>
struct DGLDataTypeTraits {
static constexpr DGLDataType dtype{0, 0, 0}; // dummy
static constexpr DGLDataType dtype{0, 0, 0}; // dummy
};
#define GEN_DGLDATATYPETRAITS_FOR(T, code, bits) \
template<> \
struct DGLDataTypeTraits<T> { \
#define GEN_DGLDATATYPETRAITS_FOR(T, code, bits) \
template <> \
struct DGLDataTypeTraits<T> { \
static constexpr DGLDataType dtype{code, bits, 1}; \
}
GEN_DGLDATATYPETRAITS_FOR(int8_t, kDGLInt, 8);
@@ -53,8 +52,8 @@ GEN_DGLDATATYPETRAITS_FOR(uint8_t, kDGLUInt, 8);
GEN_DGLDATATYPETRAITS_FOR(int16_t, kDGLInt, 16);
GEN_DGLDATATYPETRAITS_FOR(int32_t, kDGLInt, 32);
GEN_DGLDATATYPETRAITS_FOR(int64_t, kDGLInt, 64);
// XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just
// converting uints to signed DTypes.
// XXX(BarclayII) most DL frameworks do not support unsigned int and long
// arrays, so I'm just converting uints to signed DTypes.
GEN_DGLDATATYPETRAITS_FOR(uint32_t, kDGLInt, 32);
GEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64);
#ifdef DGL_USE_CUDA
@@ -98,14 +97,12 @@ class NDArray {
* @brief move constructor
* @param other The value to be moved
*/
NDArray(NDArray&& other) // NOLINT(*)
NDArray(NDArray&& other) // NOLINT(*)
: data_(other.data_) {
other.data_ = nullptr;
}
/** @brief destructor */
~NDArray() {
this->reset();
}
~NDArray() { this->reset(); }
/**
* @brief Swap this array with another NDArray
* @param other The other NDArray
@@ -130,17 +127,13 @@
*/
NDArray& operator=(NDArray&& other) { // NOLINT(*)
// copy-and-swap idiom
NDArray(std::move(other)).swap(*this); // NOLINT(*)
NDArray(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/** @return If NDArray is defined */
bool defined() const {
return data_ != nullptr;
}
bool defined() const { return data_ != nullptr; }
/** @return If both NDArray reference the same container */
bool same_as(const NDArray& other) const {
return data_ == other.data_;
}
bool same_as(const NDArray& other) const { return data_ == other.data_; }
/** @brief reset the content of NDArray to be nullptr */
inline void reset();
/**
@@ -160,22 +153,23 @@
else
return static_cast<T*>(operator->()->data);
}
/**
/**
* @brief Copy data content from/into another array.
* @param other The source array to be copied from.
* @note The copy runs on the dgl internal stream if it involves a GPU context.
* @note The copy runs on the dgl internal stream if it involves a GPU
* context.
*/
inline void CopyFrom(DGLArray* other);
inline void CopyFrom(const NDArray& other);
inline void CopyTo(DGLArray *other) const;
inline void CopyTo(const NDArray &other) const;
inline void CopyTo(DGLArray* other) const;
inline void CopyTo(const NDArray& other) const;
/**
* @brief Copy the data to another context.
* @param ctx The target context.
* @return The array under another context.
*/
inline NDArray CopyTo(const DGLContext &ctx) const;
inline NDArray CopyTo(const DGLContext& ctx) const;
/**
* @brief Return a new array with a copy of the content.
*/
@@ -224,8 +218,8 @@ class NDArray {
* @param offset The offset (in bytes) of the starting pointer.
* @note The memory size of new array must be smaller than the current one.
*/
DGL_DLL NDArray CreateView(
std::vector<int64_t> shape, DGLDataType dtype, int64_t offset = 0);
DGL_DLL NDArray
CreateView(std::vector<int64_t> shape, DGLDataType dtype, int64_t offset = 0);
/**
* @brief Create an empty NDArray.
* @param shape The shape of the new array.
@@ -233,9 +227,8 @@
* @param ctx The context of the Array.
* @return The created Array
*/
DGL_DLL static NDArray Empty(std::vector<int64_t> shape,
DGLDataType dtype,
DGLContext ctx);
DGL_DLL static NDArray Empty(
std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx);
/**
* @brief Create an empty NDArray with shared memory.
* @param name The name of shared memory.
@@ -245,11 +238,9 @@
* @param is_create whether to create shared memory.
* @return The created Array
*/
DGL_DLL static NDArray EmptyShared(const std::string &name,
std::vector<int64_t> shape,
DGLDataType dtype,
DGLContext ctx,
bool is_create);
DGL_DLL static NDArray EmptyShared(
const std::string& name, std::vector<int64_t> shape, DGLDataType dtype,
DGLContext ctx, bool is_create);
/**
* @brief Get the size of the array in the number of bytes.
*/
@@ -264,23 +255,24 @@
* @brief Create a NDArray by copying from std::vector.
* @tparam T Type of vector data. Determines the dtype of returned array.
*/
template<typename T>
template <typename T>
DGL_DLL static NDArray FromVector(
const std::vector<T>& vec, DGLContext ctx = DGLContext{kDGLCPU, 0});
/**
* @brief Create a NDArray from a raw pointer.
*/
DGL_DLL static NDArray CreateFromRaw(const std::vector<int64_t>& shape,
DGLDataType dtype, DGLContext ctx, void* raw, bool auto_free);
DGL_DLL static NDArray CreateFromRaw(
const std::vector<int64_t>& shape, DGLDataType dtype, DGLContext ctx,
void* raw, bool auto_free);
/**
* @brief Create a std::vector from a 1D NDArray.
* @tparam T Type of vector data.
* @note Type casting is NOT performed. The caller has to make sure that the vector
* type matches the dtype of NDArray.
* @note Type casting is NOT performed. The caller has to make sure that the
* vector type matches the dtype of NDArray.
*/
template<typename T>
template <typename T>
std::vector<T> ToVector() const;
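An illustrative round trip (only the declarations above are assumed); as the note says, ToVector does no casting, so T must match the dtype exactly:

NDArray arr = NDArray::FromVector(std::vector<int32_t>{1, 2, 3});
std::vector<int32_t> back = arr.ToVector<int32_t>();  // {1, 2, 3}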
std::shared_ptr<SharedMemory> GetSharedMem() const;
@@ -291,8 +283,7 @@ class NDArray {
* @param to The target array.
* @param (optional) stream The stream used in copy.
*/
DGL_DLL static void CopyFromTo(
DGLArray* from, DGLArray* to);
DGL_DLL static void CopyFromTo(DGLArray* from, DGLArray* to);
DGL_DLL static void CopyFromTo(
DGLArray* from, DGLArray* to, DGLStreamHandle stream);
@@ -337,8 +328,8 @@ class NDArray {
static void DefaultDeleter(NDArray::Container* ptr);
// Local create function which allocates tensor metadata
// but does not allocate space for the data.
static NDArray Create(std::vector<int64_t> shape,
DGLDataType dtype, DGLContext ctx);
static NDArray Create(
std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx);
// Implementation of API function
static DGLArray* MoveAsDGLArray(NDArray arr);
};
@@ -360,7 +351,6 @@
*/
inline bool SaveDGLArray(dmlc::Stream* strm, const DGLArray* tensor);
/**
* @brief Reference counted Container object used to back NDArray.
*
@@ -409,9 +399,7 @@ struct NDArray::Container {
/** @brief pointer to shared memory */
std::shared_ptr<SharedMemory> mem;
/** @brief developer function, increases reference counter */
void IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed);
}
void IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); }
/** @brief developer function, decrease reference counter */
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
@@ -444,16 +432,12 @@ struct NDArray::Container {
// implementations of inline functions
// the usages of functions are documented in place.
inline NDArray::NDArray(Container* data)
: data_(data) {
if (data_)
data_->IncRef();
inline NDArray::NDArray(Container* data) : data_(data) {
if (data_) data_->IncRef();
}
inline NDArray::NDArray(const NDArray& other)
: data_(other.data_) {
if (data_)
data_->IncRef();
inline NDArray::NDArray(const NDArray& other) : data_(other.data_) {
if (data_) data_->IncRef();
}
inline void NDArray::reset() {
@@ -473,21 +457,22 @@ inline void NDArray::CopyFrom(const NDArray& other) {
CopyFrom(&(other.data_->dl_tensor));
}
inline void NDArray::CopyTo(DGLArray *other) const {
inline void NDArray::CopyTo(DGLArray* other) const {
CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other);
}
inline void NDArray::CopyTo(const NDArray &other) const {
inline void NDArray::CopyTo(const NDArray& other) const {
CHECK(other.data_ != nullptr);
CopyTo(&(other.data_->dl_tensor));
}
inline NDArray NDArray::CopyTo(const DGLContext &ctx) const {
inline NDArray NDArray::CopyTo(const DGLContext& ctx) const {
CHECK(data_ != nullptr);
const DGLArray* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx);
NDArray ret = Empty(
std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype,
ctx);
this->CopyTo(ret);
return ret;
}
@@ -530,8 +515,7 @@ inline const DGLArray* NDArray::operator->() const {
/** @brief Magic number for NDArray file */
constexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F;
inline bool SaveDGLArray(dmlc::Stream* strm,
DGLArray* tensor) {
inline bool SaveDGLArray(dmlc::Stream* strm, DGLArray* tensor) {
uint64_t header = kDGLNDArrayMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
@@ -560,16 +544,14 @@ inline bool SaveDGLArray(dmlc::Stream* strm,
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
if (DMLC_IO_NO_ENDIAN_SWAP &&
tensor->ctx.device_type == kDGLCPU &&
tensor->strides == nullptr &&
tensor->byte_offset == 0) {
if (DMLC_IO_NO_ENDIAN_SWAP && tensor->ctx.device_type == kDGLCPU &&
tensor->strides == nullptr && tensor->byte_offset == 0) {
// quick path
strm->Write(tensor->data, data_byte_size);
} else {
std::vector<uint8_t> bytes(data_byte_size);
CHECK_EQ(DGLArrayCopyToBytes(
tensor, dmlc::BeginPtr(bytes), data_byte_size), 0)
CHECK_EQ(
DGLArrayCopyToBytes(tensor, dmlc::BeginPtr(bytes), data_byte_size), 0)
<< DGLGetLastError();
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
@@ -586,22 +568,37 @@
*/
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDGLInt: return "int";
case kDGLUInt: return "uint";
case kDGLFloat: return "float";
case kStr: return "str";
case kBytes: return "bytes";
case kHandle: return "handle";
case kNull: return "NULL";
case kObjectHandle: return "ObjectHandle";
case kArrayHandle: return "ArrayHandle";
case kDGLDataType: return "DGLDataType";
case kDGLContext: return "DGLContext";
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
case kDGLInt:
return "int";
case kDGLUInt:
return "uint";
case kDGLFloat:
return "float";
case kStr:
return "str";
case kBytes:
return "bytes";
case kHandle:
return "handle";
case kNull:
return "NULL";
case kObjectHandle:
return "ObjectHandle";
case kArrayHandle:
return "ArrayHandle";
case kDGLDataType:
return "DGLDataType";
case kDGLContext:
return "DGLContext";
case kFuncHandle:
return "FunctionHandle";
case kModuleHandle:
return "ModuleHandle";
case kNDArrayContainer:
return "NDArrayContainer";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
return "";
}
}
......@@ -612,10 +609,14 @@ inline const char* TypeCode2Str(int type_code) {
*/
inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
switch (device_type) {
case kDGLCPU: return "cpu";
case kDGLCUDA: return "cuda";
default: LOG(FATAL) << "Unsupported device type code="
<< static_cast<int>(device_type); return "";
case kDGLCPU:
return "cpu";
case kDGLCUDA:
return "cuda";
default:
LOG(FATAL) << "Unsupported device type code="
<< static_cast<int>(device_type);
return "";
}
}
......@@ -626,14 +627,18 @@ inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
*/
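// Illustrative examples, not part of this commit: the parser below maps dtype
// strings to DGLDataType values, e.g. "float32" -> {kDGLFloat, 32, 1} and
// "int64" -> {kDGLInt, 64, 1}; the bits (and optional lanes) suffix is parsed
// from the remainder of the string.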
inline DGLDataType String2DGLDataType(std::string s) {
DGLDataType t;
t.bits = 32; t.lanes = 1;
t.bits = 32;
t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDGLInt; scan = s.c_str() + 3;
t.code = kDGLInt;
scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDGLUInt; scan = s.c_str() + 4;
t.code = kDGLUInt;
scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDGLFloat; scan = s.c_str() + 5;
t.code = kDGLFloat;
scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kHandle;
t.bits = 64; // handle uses 64 bit by default.
......@@ -674,9 +679,9 @@ inline std::string DGLDataType2String(DGLDataType t) {
}
// macro to check type code.
#define DGL_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
<< TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \
#define DGL_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " \
<< TypeCode2Str(CODE)
} // namespace runtime
} // namespace dgl
......@@ -686,69 +691,69 @@ DMLC_DECLARE_TRAITS(has_saveload, dgl::runtime::NDArray, true);
} // namespace dmlc
///////////////// Operator overloading for NDArray /////////////////
dgl::runtime::NDArray operator + (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator % (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator + (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator * (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator / (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator % (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator + (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator * (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator / (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator % (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator - (const dgl::runtime::NDArray& array);
dgl::runtime::NDArray operator > (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator < (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator >= (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator <= (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator == (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator != (const dgl::runtime::NDArray& a1,
const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator > (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator < (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator >= (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator <= (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator == (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator != (const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator > (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator < (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator >= (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator <= (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator == (int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator != (int64_t lhs, const dgl::runtime::NDArray& a2);
std::ostream& operator << (std::ostream& os, dgl::runtime::NDArray array);
dgl::runtime::NDArray operator+(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator-(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator*(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator/(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator%(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator+(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator-(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator*(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator/(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator%(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator+(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator-(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator*(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator/(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator%(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator-(const dgl::runtime::NDArray& array);
dgl::runtime::NDArray operator>(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator<(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator>=(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator<=(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator==(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator!=(
const dgl::runtime::NDArray& a1, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator>(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator<(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator>=(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator<=(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator==(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator!=(const dgl::runtime::NDArray& a1, int64_t rhs);
dgl::runtime::NDArray operator>(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator<(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator>=(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator<=(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator==(int64_t lhs, const dgl::runtime::NDArray& a2);
dgl::runtime::NDArray operator!=(int64_t lhs, const dgl::runtime::NDArray& a2);
std::ostream& operator<<(std::ostream& os, dgl::runtime::NDArray array);
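// Illustrative usage sketch, not part of this commit; `a` and `b` are
// hypothetical NDArrays of matching shape:
//
//   dgl::runtime::NDArray c = a + b;     // elementwise sum
//   dgl::runtime::NDArray mask = a < 5;  // elementwise comparison mask
//   std::cout << c;                      // printed via the operator<< above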
///////////////// Operator overloading for DGLDataType /////////////////
/** @brief Check whether two data types are the same.*/
inline bool operator == (const DGLDataType& ty1, const DGLDataType& ty2) {
inline bool operator==(const DGLDataType& ty1, const DGLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
}
/** @brief Check whether two data types are different.*/
inline bool operator != (const DGLDataType& ty1, const DGLDataType& ty2) {
inline bool operator!=(const DGLDataType& ty1, const DGLDataType& ty2) {
return !(ty1 == ty2);
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator << (std::ostream& os, DGLDataType t) {
inline std::ostream& operator<<(std::ostream& os, DGLDataType t) {
os << dgl::runtime::TypeCode2Str(t.code);
if (t.code == kHandle) return os;
os << static_cast<int>(t.bits);
......@@ -762,18 +767,20 @@ inline std::ostream& operator << (std::ostream& os, DGLDataType t) {
///////////////// Operator overloading for DGLContext /////////////////
/** @brief Check whether two device contexts are the same.*/
inline bool operator == (const DGLContext& ctx1, const DGLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
inline bool operator==(const DGLContext& ctx1, const DGLContext& ctx2) {
return ctx1.device_type == ctx2.device_type &&
ctx1.device_id == ctx2.device_id;
}
/** @brief Check whether two device contexts are different.*/
inline bool operator != (const DGLContext& ctx1, const DGLContext& ctx2) {
inline bool operator!=(const DGLContext& ctx1, const DGLContext& ctx2) {
return !(ctx1 == ctx2);
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator << (std::ostream& os, const DGLContext& ctx) {
return os << dgl::runtime::DeviceTypeCode2Str(ctx.device_type) << ":" << ctx.device_id;
inline std::ostream& operator<<(std::ostream& os, const DGLContext& ctx) {
return os << dgl::runtime::DeviceTypeCode2Str(ctx.device_type) << ":"
<< ctx.device_id;
}
#endif
......
......@@ -7,19 +7,18 @@
#define DGL_RUNTIME_PARALLEL_FOR_H_
#include <dmlc/omp.h>
#include <algorithm>
#include <string>
#include <atomic>
#include <cstdlib>
#include <exception>
#include <vector>
#include <atomic>
#include <string>
#include <utility>
#include <vector>
namespace {
int64_t divup(int64_t x, int64_t y) {
return (x + y - 1) / y;
}
}
int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; }
} // namespace
namespace dgl {
namespace runtime {
......@@ -37,9 +36,7 @@ struct DefaultGrainSizeT {
}
}
size_t operator()() {
return grain_size;
}
size_t operator()() { return grain_size; }
};
} // namespace
......@@ -48,7 +45,9 @@ inline size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) {
if (omp_in_parallel() || end - begin <= grain_size || end - begin == 1)
return 1;
return std::min(static_cast<int64_t>(omp_get_max_threads()), divup(end - begin, grain_size));
return std::min(
static_cast<int64_t>(omp_get_max_threads()),
divup(end - begin, grain_size));
#else
return 1;
#endif
......@@ -60,15 +59,13 @@ static DefaultGrainSizeT default_grain_size;
* @brief OpenMP-based parallel for loop.
*
* It requires each thread's workload to have at least \a grain_size elements.
 * The loop body will be a function that takes in two arguments \a begin and \a end, which
 * stand for the starting (inclusive) and ending index (exclusive) of the workload.
 * The loop body will be a function that takes in two arguments \a begin and \a
 * end, which stand for the starting (inclusive) and ending index (exclusive)
 * of the workload.
*/
template <typename F>
void parallel_for(
const size_t begin,
const size_t end,
const size_t grain_size,
F&& f) {
const size_t begin, const size_t end, const size_t grain_size, F&& f) {
if (begin >= end) {
return;
}
......@@ -89,13 +86,11 @@ void parallel_for(
try {
f(begin_tid, end_tid);
} catch (...) {
if (!err_flag.test_and_set())
eptr = std::current_exception();
if (!err_flag.test_and_set()) eptr = std::current_exception();
}
}
}
if (eptr)
std::rethrow_exception(eptr);
if (eptr) std::rethrow_exception(eptr);
#else
f(begin, end);
#endif
......@@ -110,22 +105,21 @@ void parallel_for(
* parallel for pragma with static scheduling.
*/
template <typename F>
void parallel_for(
const size_t begin,
const size_t end,
F&& f) {
void parallel_for(const size_t begin, const size_t end, F&& f) {
parallel_for(begin, end, default_grain_size(), std::forward<F>(f));
}
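// Illustrative usage sketch, not part of this commit: filling a buffer with
// parallel_for. Each invocation of the lambda handles one contiguous chunk
// [b, e); the buffer `out` and the grain size 1024 are hypothetical.
//
//   std::vector<float> out(1 << 20);
//   dgl::runtime::parallel_for(
//       0, out.size(), 1024, [&out](size_t b, size_t e) {
//         for (size_t i = b; i < e; ++i) out[i] = 0.5f * i;
//       });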
/**
* @brief OpenMP-based two-stage parallel reduction.
*
* The first-stage reduction function \a f works in parallel. Each thread's workload has
* at least \a grain_size elements. The loop body will be a function that takes in
* the starting index (inclusive), the ending index (exclusive), and the reduction identity.
* The first-stage reduction function \a f works in parallel. Each thread's
* workload has at least \a grain_size elements. The loop body will be a
* function that takes in the starting index (inclusive), the ending index
* (exclusive), and the reduction identity.
*
* The second-stage reduction function \a sf is a binary function working in the main
* thread. It aggregates the partially reduced result computed from each thread.
* The second-stage reduction function \a sf is a binary function working in the
* main thread. It aggregates the partially reduced result computed from each
* thread.
*
* Example to compute a parallelized max reduction of an array \c a:
*
......@@ -134,11 +128,9 @@ void parallel_for(
* 100, // ending index
* 1, // grain size
 * -std::numeric_limits<float>::infinity(), // identity
* [&a] (int begin, int end, float ident) { // first-stage partial reducer
* float result = ident;
* for (int i = begin; i < end; ++i)
* result = std::max(result, a[i]);
* return result;
 *     [&a] (int begin, int end, float ident) {  // first-stage partial reducer
 *       float result = ident;
 *       for (int i = begin; i < end; ++i)
 *         result = std::max(result, a[i]);
 *       return result;
* },
* [] (float result, float partial_result) {
* return std::max(result, partial_result);
......@@ -146,12 +138,8 @@ void parallel_for(
*/
template <typename DType, typename F, typename SF>
DType parallel_reduce(
const size_t begin,
const size_t end,
const size_t grain_size,
const DType ident,
const F& f,
const SF& sf) {
const size_t begin, const size_t end, const size_t grain_size,
const DType ident, const F& f, const SF& sf) {
if (begin >= end) {
return ident;
}
......@@ -174,17 +162,14 @@ DType parallel_reduce(
try {
results[tid] = f(begin_tid, end_tid, ident);
} catch (...) {
if (!err_flag.test_and_set())
eptr = std::current_exception();
if (!err_flag.test_and_set()) eptr = std::current_exception();
}
}
}
if (eptr)
std::rethrow_exception(eptr);
if (eptr) std::rethrow_exception(eptr);
DType out = ident;
for (int64_t i = 0; i < num_threads; ++i)
out = sf(out, results[i]);
for (int64_t i = 0; i < num_threads; ++i) out = sf(out, results[i]);
return out;
}
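// Illustrative usage sketch, not part of this commit: a parallel sum with
// parallel_reduce. The vector `a` and the grain size 4096 are hypothetical.
//
//   std::vector<double> a = ...;
//   double total = dgl::runtime::parallel_reduce(
//       0, a.size(), 4096, 0.0,
//       [&a](size_t b, size_t e, double ident) {
//         double partial = ident;  // first stage: per-thread partial sum
//         for (size_t i = b; i < e; ++i) partial += a[i];
//         return partial;
//       },
//       [](double x, double y) { return x + y; });  // second stage: combine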
......
......@@ -7,12 +7,14 @@
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/shared_mem.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/shared_mem.h>
#include <sstream>
#include "../c_api_common.h"
#include "./array_op.h"
#include "./arith.h"
#include "./array_op.h"
using namespace dgl::runtime;
......@@ -73,17 +75,14 @@ template NDArray Full<double>(double val, int64_t length, DGLContext ctx);
IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64)
<< "Invalid ID type. Must be int32 or int64, but got int"
<< static_cast<int>(bits) << ".";
if (arr->dtype.bits == bits)
return arr;
if (arr.NumElements() == 0)
return NewIdArray(arr->shape[0], arr->ctx, bits);
<< "Invalid ID type. Must be int32 or int64, but got int"
<< static_cast<int>(bits) << ".";
if (arr->dtype.bits == bits) return arr;
if (arr.NumElements() == 0) return NewIdArray(arr->shape[0], arr->ctx, bits);
IdArray ret;
ATEN_XPU_SWITCH_CUDA(arr->ctx.device_type, XPU, "AsNumBits", {
ATEN_ID_TYPE_SWITCH(arr->dtype, IdType, {
ret = impl::AsNumBits<XPU, IdType>(arr, bits);
});
ATEN_ID_TYPE_SWITCH(
arr->dtype, IdType, { ret = impl::AsNumBits<XPU, IdType>(arr, bits); });
});
return ret;
}
......@@ -98,14 +97,12 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
const int64_t len = lhs->shape[0];
ret = NewIdArray(2 * len, lhs->ctx, lhs->dtype.bits);
device->CopyDataFromTo(lhs.Ptr<IdType>(), 0,
ret.Ptr<IdType>(), 0,
len * sizeof(IdType),
ctx, ctx, lhs->dtype);
device->CopyDataFromTo(rhs.Ptr<IdType>(), 0,
ret.Ptr<IdType>(), len * sizeof(IdType),
len * sizeof(IdType),
ctx, ctx, lhs->dtype);
device->CopyDataFromTo(
lhs.Ptr<IdType>(), 0, ret.Ptr<IdType>(), 0, len * sizeof(IdType), ctx,
ctx, lhs->dtype);
device->CopyDataFromTo(
rhs.Ptr<IdType>(), 0, ret.Ptr<IdType>(), len * sizeof(IdType),
len * sizeof(IdType), ctx, ctx, lhs->dtype);
});
return ret;
}
......@@ -127,11 +124,11 @@ NDArray IndexSelect(NDArray array, IdArray index) {
return ret;
}
template<typename ValueType>
template <typename ValueType>
ValueType IndexSelect(NDArray array, int64_t index) {
CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
CHECK(index >= 0 && index < array.NumElements())
<< "Index " << index << " is out of bound.";
<< "Index " << index << " is out of bound.";
ValueType ret = 0;
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
......@@ -150,17 +147,17 @@ template double IndexSelect<double>(NDArray array, int64_t index);
NDArray IndexSelect(NDArray array, int64_t start, int64_t end) {
CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
CHECK(start >= 0 && start < array.NumElements())
<< "Index " << start << " is out of bound.";
<< "Index " << start << " is out of bound.";
CHECK(end >= 0 && end <= array.NumElements())
<< "Index " << end << " is out of bound.";
<< "Index " << end << " is out of bound.";
CHECK_LE(start, end);
auto device = runtime::DeviceAPI::Get(array->ctx);
const int64_t len = end - start;
NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
device->CopyDataFromTo(array->data, start * sizeof(DType),
ret->data, 0, len * sizeof(DType),
array->ctx, ret->ctx, array->dtype);
device->CopyDataFromTo(
array->data, start * sizeof(DType), ret->data, 0, len * sizeof(DType),
array->ctx, ret->ctx, array->dtype);
});
return ret;
}
......@@ -182,8 +179,7 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
CHECK_SAME_CONTEXT(index, value);
CHECK_SAME_CONTEXT(index, out);
CHECK_EQ(value->shape[0], index->shape[0]);
if (index->shape[0] == 0)
return;
if (index->shape[0] == 0) return;
ATEN_XPU_SWITCH_CUDA(value->ctx.device_type, XPU, "Scatter_", {
ATEN_DTYPE_SWITCH(value->dtype, DType, "values", {
ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
......@@ -225,31 +221,25 @@ NDArray Concat(const std::vector<IdArray>& arrays) {
CHECK_SAME_CONTEXT(arrays[0], arrays[i]);
}
NDArray ret_arr = NDArray::Empty({len},
arrays[0]->dtype,
arrays[0]->ctx);
NDArray ret_arr = NDArray::Empty({len}, arrays[0]->dtype, arrays[0]->ctx);
auto device = runtime::DeviceAPI::Get(arrays[0]->ctx);
for (size_t i = 0; i < arrays.size(); ++i) {
ATEN_DTYPE_SWITCH(arrays[i]->dtype, DType, "array", {
device->CopyDataFromTo(
static_cast<DType*>(arrays[i]->data),
0,
static_cast<DType*>(ret_arr->data),
offset,
arrays[i]->shape[0] * sizeof(DType),
arrays[i]->ctx,
ret_arr->ctx,
arrays[i]->dtype);
offset += arrays[i]->shape[0] * sizeof(DType);
static_cast<DType*>(arrays[i]->data), 0,
static_cast<DType*>(ret_arr->data), offset,
arrays[i]->shape[0] * sizeof(DType), arrays[i]->ctx, ret_arr->ctx,
arrays[i]->dtype);
offset += arrays[i]->shape[0] * sizeof(DType);
});
}
return ret_arr;
}
template<typename ValueType>
template <typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
std::tuple<NDArray, IdArray, IdArray> ret;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Pack", {
......@@ -262,8 +252,10 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
template std::tuple<NDArray, IdArray, IdArray> Pack<int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint32_t>(NDArray, uint32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint64_t>(NDArray, uint64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint32_t>(
NDArray, uint32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint64_t>(
NDArray, uint64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<double>(NDArray, double);
......@@ -292,9 +284,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
IdArray NonZero(NDArray array) {
IdArray ret;
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "NonZero", {
ATEN_ID_TYPE_SWITCH(array->dtype, DType, {
ret = impl::NonZero<XPU, DType>(array);
});
ATEN_ID_TYPE_SWITCH(
array->dtype, DType, { ret = impl::NonZero<XPU, DType>(array); });
});
return ret;
}
......@@ -322,8 +313,7 @@ std::string ToDebugString(NDArray array) {
oss << a.Ptr<DType>()[i] << ", ";
}
});
if (a.NumElements() > 10)
oss << "...";
if (a.NumElements() > 10) oss << "...";
oss << "], dtype=" << array->dtype << ", ctx=" << array->ctx << ")";
return oss.str();
}
......@@ -396,8 +386,7 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
}
bool CSRIsSorted(CSRMatrix csr) {
if (csr.indices->shape[0] <= 1)
return true;
if (csr.indices->shape[0] <= 1) return true;
bool ret = false;
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsSorted", {
ret = impl::CSRIsSorted<XPU, IdType>(csr);
......@@ -417,14 +406,16 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
}
template <typename DType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {
NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {
NDArray ret;
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
CHECK_SAME_CONTEXT(rows, cols);
CHECK_SAME_CONTEXT(rows, weights);
ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetData", {
ret = impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols, weights, filler);
ret =
impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols, weights, filler);
});
return ret;
}
......@@ -459,11 +450,12 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
COOMatrix ret;
if (data_as_order) {
ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRToCOODataAsOrder", {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
ret = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);
});
});
ATEN_XPU_SWITCH_CUDA(
csr.indptr->ctx.device_type, XPU, "CSRToCOODataAsOrder", {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
ret = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);
});
});
} else {
ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRToCOO", {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
......@@ -506,17 +498,16 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
}
void CSRSort_(CSRMatrix* csr) {
if (csr->sorted)
return;
ATEN_CSR_SWITCH_CUDA(*csr, XPU, IdType, "CSRSort_", {
impl::CSRSort_<XPU, IdType>(csr);
});
if (csr->sorted) return;
ATEN_CSR_SWITCH_CUDA(
*csr, XPU, IdType, "CSRSort_", { impl::CSRSort_<XPU, IdType>(csr); });
}
std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, IdArray tag, int64_t num_tags) {
const CSRMatrix& csr, IdArray tag, int64_t num_tags) {
CHECK_EQ(csr.indices->shape[0], tag->shape[0])
<< "The length of the tag array should be equal to the number of non-zero data.";
<< "The length of the tag array should be equal to the number of "
"non-zero data.";
CHECK_SAME_CONTEXT(csr.indices, tag);
CHECK_INT(tag, "tag");
std::pair<CSRMatrix, NDArray> ret;
......@@ -528,7 +519,8 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
return ret;
}
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
CSRMatrix CSRReorder(
CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRReorder", {
ret = impl::CSRReorder<XPU, IdType>(csr, new_row_ids, new_col_ids);
......@@ -545,23 +537,26 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
}
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace) {
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace) {
COOMatrix ret;
if (IsNullArray(prob_or_mask)) {
ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSamplingUniform", {
ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
});
ATEN_CSR_SWITCH_CUDA_UVA(
mat, rows, XPU, IdType, "CSRRowWiseSamplingUniform", {
ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(
mat, rows, num_samples, replace);
});
} else {
    // A pinned prob_or_mask with rows on GPU is also a valid combination.
CHECK_VALID_CONTEXT(prob_or_mask, rows);
ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
CHECK(!(prob_or_mask->dtype.bits == 8 && XPU == kDGLCUDA)) <<
"GPU sampling with masks is currently not supported yet.";
CHECK(!(prob_or_mask->dtype.bits == 8 && XPU == kDGLCUDA))
<< "GPU sampling with masks is currently not supported yet.";
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
prob_or_mask->dtype, FloatType, "probability or mask", {
ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob_or_mask, replace);
});
ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob_or_mask, replace);
});
});
}
return ret;
......@@ -569,26 +564,28 @@ COOMatrix CSRRowWiseSampling(
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<NDArray>& prob_or_mask,
bool replace, bool rowwise_etype_sorted) {
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace,
bool rowwise_etype_sorted) {
COOMatrix ret;
CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", {
if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {
ret = impl::CSRRowWisePerEtypeSamplingUniform<XPU, IdType>(
mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted);
mat, rows, eid2etype_offset, num_samples, replace,
rowwise_etype_sorted);
} else {
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
prob_or_mask[0]->dtype, DType, "probability or mask", {
ret = impl::CSRRowWisePerEtypeSampling<XPU, IdType, DType>(
mat, rows, eid2etype_offset, num_samples, prob_or_mask, replace, rowwise_etype_sorted);
});
ret = impl::CSRRowWisePerEtypeSampling<XPU, IdType, DType>(
mat, rows, eid2etype_offset, num_samples, prob_or_mask, replace,
rowwise_etype_sorted);
});
}
});
return ret;
}
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
COOMatrix ret;
......@@ -602,16 +599,12 @@ COOMatrix CSRRowWiseTopk(
}
COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat,
IdArray rows,
int64_t num_samples,
NDArray tag_offset,
FloatArray bias,
bool replace) {
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
FloatArray bias, bool replace) {
COOMatrix ret;
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSamplingBiased", {
ATEN_FLOAT_TYPE_SWITCH(bias->dtype, FloatType, "bias", {
ret = impl::CSRRowWiseSamplingBiased<XPU, IdType, FloatType>(
ret = impl::CSRRowWiseSamplingBiased<XPU, IdType, FloatType>(
mat, rows, num_samples, tag_offset, bias, replace);
});
});
......@@ -619,12 +612,8 @@ COOMatrix CSRRowWiseSamplingBiased(
}
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr,
int64_t num_samples,
int num_trials,
bool exclude_self_loops,
bool replace,
double redundancy) {
const CSRMatrix& csr, int64_t num_samples, int num_trials,
bool exclude_self_loops, bool replace, double redundancy) {
CHECK_GT(num_samples, 0) << "Number of samples must be positive";
CHECK_GT(num_trials, 0) << "Number of sampling trials must be positive";
std::pair<IdArray, IdArray> result;
......@@ -635,16 +624,16 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
return result;
}
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
CSRMatrix ret;
CHECK_GT(csrs.size(), 1) << "UnionCsr creates a union of multiple CSRMatrixes";
CHECK_GT(csrs.size(), 1)
<< "UnionCsr creates a union of multiple CSRMatrixes";
// sanity check
for (size_t i = 1; i < csrs.size(); ++i) {
    CHECK_EQ(csrs[0].num_rows, csrs[i].num_rows) <<
      "UnionCsr requires both CSRMatrix to have the same number of rows";
    CHECK_EQ(csrs[0].num_cols, csrs[i].num_cols) <<
      "UnionCsr requires both CSRMatrix to have the same number of cols";
    CHECK_EQ(csrs[0].num_rows, csrs[i].num_rows)
        << "UnionCsr requires both CSRMatrix to have the same number of rows";
    CHECK_EQ(csrs[0].num_cols, csrs[i].num_cols)
        << "UnionCsr requires both CSRMatrix to have the same number of cols";
CHECK_SAME_CONTEXT(csrs[0].indptr, csrs[i].indptr);
CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);
}
......@@ -655,9 +644,7 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
return ret;
}
std::tuple<CSRMatrix, IdArray, IdArray>
CSRToSimple(const CSRMatrix& csr) {
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr) {
std::tuple<CSRMatrix, IdArray, IdArray> ret;
CSRMatrix sorted_csr = (CSRIsSorted(csr)) ? csr : CSRSort(csr);
......@@ -709,7 +696,8 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {
return ret;
}
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(COOMatrix coo, int64_t row) {
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
COOMatrix coo, int64_t row) {
std::pair<NDArray, NDArray> ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetRowDataAndIndices", {
ret = impl::COOGetRowDataAndIndices<XPU, IdType>(coo, row);
......@@ -741,9 +729,8 @@ COOMatrix COOTranspose(COOMatrix coo) {
CSRMatrix COOToCSR(COOMatrix coo) {
CSRMatrix ret;
ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "COOToCSR", {
ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {
ret = impl::COOToCSR<XPU, IdType>(coo);
});
ATEN_ID_TYPE_SWITCH(
coo.row->dtype, IdType, { ret = impl::COOToCSR<XPU, IdType>(coo); });
});
return ret;
}
......@@ -773,8 +760,7 @@ COOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) {
}
void COOSort_(COOMatrix* mat, bool sort_column) {
if ((mat->row_sorted && !sort_column) || mat->col_sorted)
return;
if ((mat->row_sorted && !sort_column) || mat->col_sorted) return;
ATEN_XPU_SWITCH_CUDA(mat->row->ctx.device_type, XPU, "COOSort_", {
ATEN_ID_TYPE_SWITCH(mat->row->dtype, IdType, {
impl::COOSort_<XPU, IdType>(mat, sort_column);
......@@ -783,8 +769,7 @@ void COOSort_(COOMatrix* mat, bool sort_column) {
}
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
if (coo.row->shape[0] <= 1)
return {true, true};
if (coo.row->shape[0] <= 1) return {true, true};
std::pair<bool, bool> ret;
ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOIsSorted", {
ret = impl::COOIsSorted<XPU, IdType>(coo);
......@@ -792,7 +777,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return ret;
}
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
COOMatrix COOReorder(
COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COOReorder", {
ret = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);
......@@ -809,17 +795,19 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) {
}
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace) {
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseSampling", {
if (IsNullArray(prob_or_mask)) {
ret = impl::COORowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
ret = impl::COORowWiseSamplingUniform<XPU, IdType>(
mat, rows, num_samples, replace);
} else {
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
prob_or_mask->dtype, DType, "probability or mask", {
ret = impl::COORowWiseSampling<XPU, IdType, DType>(
mat, rows, num_samples, prob_or_mask, replace);
});
ret = impl::COORowWiseSampling<XPU, IdType, DType>(
mat, rows, num_samples, prob_or_mask, replace);
});
}
});
return ret;
......@@ -827,20 +815,21 @@ COOMatrix COORowWiseSampling(
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<NDArray>& prob_or_mask,
bool replace) {
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace) {
COOMatrix ret;
CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", {
if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {
ret = impl::COORowWisePerEtypeSamplingUniform<XPU, IdType>(
mat, rows, eid2etype_offset, num_samples, replace);
mat, rows, eid2etype_offset, num_samples, replace);
} else {
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
prob_or_mask[0]->dtype, DType, "probability or mask", {
ret = impl::COORowWisePerEtypeSampling<XPU, IdType, DType>(
mat, rows, eid2etype_offset, num_samples, prob_or_mask, replace);
});
ret = impl::COORowWisePerEtypeSampling<XPU, IdType, DType>(
mat, rows, eid2etype_offset, num_samples, prob_or_mask,
replace);
});
}
});
return ret;
......@@ -876,7 +865,7 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
return ret;
}
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, "COOLineGraph", {
ret = impl::COOLineGraph<XPU, IdType>(coo, backtracking);
......@@ -886,13 +875,14 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
COOMatrix ret;
CHECK_GT(coos.size(), 1) << "UnionCoo creates a union of multiple COOMatrixes";
CHECK_GT(coos.size(), 1)
<< "UnionCoo creates a union of multiple COOMatrixes";
// sanity check
for (size_t i = 1; i < coos.size(); ++i) {
    CHECK_EQ(coos[0].num_rows, coos[i].num_rows) <<
      "UnionCoo requires both COOMatrix to have the same number of rows";
    CHECK_EQ(coos[0].num_cols, coos[i].num_cols) <<
      "UnionCoo requires both COOMatrix to have the same number of cols";
    CHECK_EQ(coos[0].num_rows, coos[i].num_rows)
        << "UnionCoo requires both COOMatrix to have the same number of rows";
    CHECK_EQ(coos[0].num_cols, coos[i].num_cols)
        << "UnionCoo requires both COOMatrix to have the same number of cols";
CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);
CHECK_SAME_DTYPE(coos[0].row, coos[i].row);
}
......@@ -914,20 +904,19 @@ COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
if (has_data) {
std::vector<IdArray> eid_data;
eid_data.push_back(COOHasData(coos[0]) ?
coos[0].data :
Range(0,
coos[0].row->shape[0],
coos[0].row->dtype.bits,
coos[0].row->ctx));
eid_data.push_back(
COOHasData(coos[0]) ? coos[0].data
: Range(
0, coos[0].row->shape[0],
coos[0].row->dtype.bits, coos[0].row->ctx));
int64_t num_edges = coos[0].row->shape[0];
for (size_t i = 1; i < coos.size(); ++i) {
eid_data.push_back(COOHasData(coos[i]) ?
coos[i].data + num_edges :
Range(num_edges,
num_edges + coos[i].row->shape[0],
coos[i].row->dtype.bits,
coos[i].row->ctx));
eid_data.push_back(
COOHasData(coos[i])
? coos[i].data + num_edges
: Range(
num_edges, num_edges + coos[i].row->shape[0],
coos[i].row->dtype.bits, coos[i].row->ctx));
num_edges += coos[i].row->shape[0];
}
......@@ -935,71 +924,62 @@ COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
}
return COOMatrix(
coos[0].num_rows,
coos[0].num_cols,
row,
col,
data,
false,
false);
coos[0].num_rows, coos[0].num_cols, row, col, data, false, false);
}
std::tuple<COOMatrix, IdArray, IdArray>
COOToSimple(const COOMatrix& coo) {
std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo) {
// coo column sorted
const COOMatrix sorted_coo = COOSort(coo, true);
const IdArray eids_shuffled = COOHasData(sorted_coo) ?
sorted_coo.data :
Range(0, sorted_coo.row->shape[0], sorted_coo.row->dtype.bits, sorted_coo.row->ctx);
const auto &coalesced_result = COOCoalesce(sorted_coo);
const COOMatrix &coalesced_adj = coalesced_result.first;
const IdArray &count = coalesced_result.second;
const IdArray eids_shuffled =
COOHasData(sorted_coo)
? sorted_coo.data
: Range(
0, sorted_coo.row->shape[0], sorted_coo.row->dtype.bits,
sorted_coo.row->ctx);
const auto& coalesced_result = COOCoalesce(sorted_coo);
const COOMatrix& coalesced_adj = coalesced_result.first;
const IdArray& count = coalesced_result.second;
/**
* eids_shuffled actually already contains the mapping from old edge space to the
* new one:
* eids_shuffled actually already contains the mapping from old edge space to
* the new one:
*
* * eids_shuffled[0:count[0]] indicates the original edge IDs that coalesced into new
* edge #0.
* * eids_shuffled[count[0]:count[0] + count[1]] indicates those that coalesced into
* new edge #1.
* * eids_shuffled[count[0] + count[1]:count[0] + count[1] + count[2]] indicates those
* that coalesced into new edge #2.
* * eids_shuffled[0:count[0]] indicates the original edge IDs that coalesced
* into new edge #0.
* * eids_shuffled[count[0]:count[0] + count[1]] indicates those that
* coalesced into new edge #1.
* * eids_shuffled[count[0] + count[1]:count[0] + count[1] + count[2]]
* indicates those that coalesced into new edge #2.
* * etc.
*
* Here, we need to translate eids_shuffled to an array "eids_remapped" such that
* eids_remapped[i] indicates the new edge ID the old edge #i is mapped to. The
* translation can simply be achieved by (in numpy code):
* Here, we need to translate eids_shuffled to an array "eids_remapped" such
* that eids_remapped[i] indicates the new edge ID the old edge #i is mapped
* to. The translation can simply be achieved by (in numpy code):
*
 * new_eid_for_eids_shuffled = np.arange(len(count)).repeat(count)
* eids_remapped = np.zeros_like(new_eid_for_eids_shuffled)
* eids_remapped[eids_shuffled] = new_eid_for_eids_shuffled
*/
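  // Illustrative worked example, not part of this commit, using hypothetical
  // values: suppose 3 original edges where old edges #0 and #2 coalesce into
  // new edge #0, so count = [2, 1] and eids_shuffled = [2, 0, 1]. Then
  // new_eid_for_eids_shuffled = [0, 0, 1], and the Scatter below yields
  // eids_remapped = [0, 1, 0], i.e. old edge #i maps to new edge
  // eids_remapped[i].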
const IdArray new_eids = Range(
0, coalesced_adj.row->shape[0], coalesced_adj.row->dtype.bits, coalesced_adj.row->ctx);
0, coalesced_adj.row->shape[0], coalesced_adj.row->dtype.bits,
coalesced_adj.row->ctx);
const IdArray eids_remapped = Scatter(Repeat(new_eids, count), eids_shuffled);
COOMatrix ret = COOMatrix(
coalesced_adj.num_rows,
coalesced_adj.num_cols,
coalesced_adj.row,
coalesced_adj.col,
NullArray(),
true,
true);
coalesced_adj.num_rows, coalesced_adj.num_cols, coalesced_adj.row,
coalesced_adj.col, NullArray(), true, true);
return std::make_tuple(ret, count, eids_remapped);
}
///////////////////////// Graph Traverse routines //////////////////////////
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type) <<
    "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype) <<
    "Graph and source should be in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols) <<
    "Graph traversal can only work on square-shaped CSR.";
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should be in the same dtype";
CHECK_EQ(csr.num_rows, csr.num_cols)
<< "Graph traversal can only work on square-shaped CSR.";
ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSNodesFrontiers", {
ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
ret = impl::BFSNodesFrontiers<XPU, IdType>(csr, source);
......@@ -1010,12 +990,12 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type) <<
    "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype) <<
    "Graph and source should be in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols) <<
    "Graph traversal can only work on square-shaped CSR.";
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should be in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSEdgesFrontiers", {
ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
ret = impl::BFSEdgesFrontiers<XPU, IdType>(csr, source);
......@@ -1026,24 +1006,25 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
Frontiers ret;
CHECK_EQ(csr.num_rows, csr.num_cols) <<
"Graph traversal can only work on square-shaped CSR.";
ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, "TopologicalNodesFrontiers", {
ATEN_ID_TYPE_SWITCH(csr.indices->dtype, IdType, {
ret = impl::TopologicalNodesFrontiers<XPU, IdType>(csr);
});
});
CHECK_EQ(csr.num_rows, csr.num_cols)
<< "Graph traversal can only work on square-shaped CSR.";
ATEN_XPU_SWITCH(
csr.indptr->ctx.device_type, XPU, "TopologicalNodesFrontiers", {
ATEN_ID_TYPE_SWITCH(csr.indices->dtype, IdType, {
ret = impl::TopologicalNodesFrontiers<XPU, IdType>(csr);
});
});
return ret;
}
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type) <<
    "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype) <<
    "Graph and source should be in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols) <<
    "Graph traversal can only work on square-shaped CSR.";
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should be in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSEdges", {
ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
ret = impl::DGLDFSEdges<XPU, IdType>(csr, source);
......@@ -1052,104 +1033,97 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
return ret;
}
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels) {
Frontiers DGLDFSLabeledEdges(
const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_nontree_edge, const bool return_labels) {
Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type) <<
    "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype) <<
    "Graph and source should be in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols) <<
    "Graph traversal can only work on square-shaped CSR.";
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should be in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should be in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSLabeledEdges", {
ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
ret = impl::DGLDFSLabeledEdges<XPU, IdType>(csr,
source,
has_reverse_edge,
has_nontree_edge,
return_labels);
ret = impl::DGLDFSLabeledEdges<XPU, IdType>(
csr, source, has_reverse_edge, has_nontree_edge, return_labels);
});
});
return ret;
}
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
*rv = spmat->format;
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
*rv = spmat->format;
});
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumRows")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
*rv = spmat->num_rows;
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
*rv = spmat->num_rows;
});
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumCols")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
*rv = spmat->num_cols;
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
*rv = spmat->num_cols;
});
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetIndices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
const int64_t i = args[1];
*rv = spmat->indices[i];
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
const int64_t i = args[1];
*rv = spmat->indices[i];
});
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFlags")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
List<Value> flags;
for (bool flg : spmat->flags) {
flags.push_back(Value(MakeValue(flg)));
}
*rv = flags;
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
SparseMatrixRef spmat = args[0];
List<Value> flags;
for (bool flg : spmat->flags) {
flags.push_back(Value(MakeValue(flg)));
}
*rv = flags;
});
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLCreateSparseMatrix")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t format = args[0];
const int64_t nrows = args[1];
const int64_t ncols = args[2];
const List<Value> indices = args[3];
const List<Value> flags = args[4];
std::shared_ptr<SparseMatrix> spmat(new SparseMatrix(
format, nrows, ncols,
ListValueToVector<IdArray>(indices),
.set_body([](DGLArgs args, DGLRetValue* rv) {
const int32_t format = args[0];
const int64_t nrows = args[1];
const int64_t ncols = args[2];
const List<Value> indices = args[3];
const List<Value> flags = args[4];
std::shared_ptr<SparseMatrix> spmat(new SparseMatrix(
format, nrows, ncols, ListValueToVector<IdArray>(indices),
ListValueToVector<bool>(flags)));
*rv = SparseMatrixRef(spmat);
});
*rv = SparseMatrixRef(spmat);
});
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const std::string name = args[0];
.set_body([](DGLArgs args, DGLRetValue* rv) {
const std::string name = args[0];
#ifndef _WIN32
*rv = SharedMemory::Exist(name);
*rv = SharedMemory::Exist(name);
#else
*rv = false;
*rv = false;
#endif // _WIN32
});
});
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray array = args[0];
CHECK_EQ(array->dtype.code, kDGLUInt);
std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
DGLDataType dtype = array->dtype;
dtype.code = kDGLInt;
*rv = array.CreateView(shape, dtype, 0);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray array = args[0];
CHECK_EQ(array->dtype.code, kDGLUInt);
std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
DGLDataType dtype = array->dtype;
dtype.code = kDGLInt;
*rv = array.CreateView(shape, dtype, 0);
});
} // namespace aten
} // namespace dgl
std::ostream& operator << (std::ostream& os, dgl::runtime::NDArray array) {
std::ostream& operator<<(std::ostream& os, dgl::runtime::NDArray array) {
return os << dgl::aten::ToDebugString(array);
}
......@@ -8,9 +8,10 @@
#include <dgl/array.h>
#include <dgl/graph_traversal.h>
#include <vector>
#include <tuple>
#include <utility>
#include <vector>
namespace dgl {
namespace aten {
......@@ -82,7 +83,8 @@ template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
runtime::NDArray CSRIsNonZero(
CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr);
......@@ -104,19 +106,21 @@ bool CSRIsSorted(CSRMatrix csr);
template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids,
runtime::NDArray weights, DType filler);
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
bool return_eids, runtime::NDArray weights, DType filler);
template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler) {
return CSRGetData<XPU, IdType, DType>(csr, rows, cols, false, weights, filler);
return CSRGetData<XPU, IdType, DType>(
csr, rows, cols, false, weights, filler);
}
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, true, NullArray(rows->dtype), -1);
return CSRGetData<XPU, IdType, IdType>(
csr, rows, cols, true, NullArray(rows->dtype), -1);
}
template <DGLDeviceType XPU, typename IdType>
......@@ -141,20 +145,23 @@ template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
CSRMatrix CSRSliceMatrix(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr);
template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, IdArray tag_array, int64_t num_tags);
const CSRMatrix& csr, IdArray tag_array, int64_t num_tags);
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
CSRMatrix CSRReorder(
CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
COOMatrix COOReorder(
COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
......@@ -162,15 +169,16 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace);
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace, bool rowwise_etype_sorted);
const std::vector<NDArray>& prob_or_mask, bool replace,
bool rowwise_etype_sorted);
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
......@@ -178,9 +186,9 @@ COOMatrix CSRRowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace, bool rowwise_etype_sorted);
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace,
bool rowwise_etype_sorted);
// FloatType is the type of weight data.
template <DGLDeviceType XPU, typename IdType, typename DType>
......@@ -189,21 +197,13 @@ COOMatrix CSRRowWiseTopk(
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat,
IdArray rows,
int64_t num_samples,
NDArray tag_offset,
FloatArray bias,
bool replace);
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
FloatArray bias, bool replace);
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr,
int64_t num_samples,
int num_trials,
bool exclude_self_loops,
bool replace,
double redundancy);
const CSRMatrix& csr, int64_t num_samples, int num_trials,
bool exclude_self_loops, bool replace, double redundancy);
// Union CSRMatrixes
template <DGLDeviceType XPU, typename IdType>
......@@ -218,7 +218,8 @@ template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOIsNonZero(COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
runtime::NDArray COOIsNonZero(
COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo);
......@@ -230,15 +231,16 @@ template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);
template <DGLDeviceType XPU, typename IdType>
std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix coo, int64_t row);
std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
COOMatrix coo, int64_t row);
template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
runtime::NDArray COOGetData(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo);
......@@ -253,7 +255,8 @@ template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
COOMatrix COOSliceMatrix(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
......@@ -273,13 +276,13 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace);
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace);
// FloatType is the type of probability data.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace);
......@@ -289,8 +292,7 @@ COOMatrix COORowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace);
// FloatType is the type of weight data.
......@@ -313,14 +315,12 @@ template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
Frontiers DGLDFSLabeledEdges(
const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_nontree_edge, const bool return_labels);
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
} // namespace impl
} // namespace aten
......
/**
* Copyright (c) 2020 by Contributors
 * @file kernel/cpu/gather_mm.cc
* @brief GatherMM C APIs and definitions.
*/
#include "./gather_mm.h"
#include <dgl/array.h>
namespace dgl {
......@@ -11,81 +12,72 @@ namespace aten {
/** @brief Generalized SegmentMM. */
template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray seglen_A,
bool a_trans, bool b_trans) {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
void SegmentMM(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans) {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
}
template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A,
const NDArray dC,
NDArray dB,
const NDArray seglen) {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
void SegmentMMBackwardB(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
}
/** @brief Generalized GatherMM. */
template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
void GatherMM(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}
/** @brief Generalized GatherMMScatter. */
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
void GatherMMScatter(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}
template void GatherMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMMScatter<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
......
......@@ -7,13 +7,14 @@
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_
#include <dgl/array.h>
#include <dmlc/omp.h>
#include <dgl/runtime/parallel_for.h>
#include <functional>
#include <dmlc/omp.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include <memory>
namespace dgl {
namespace aten {
......@@ -41,10 +42,10 @@ namespace impl {
template <typename IdxType>
using PickFn = std::function<void(
IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx)>;
const IdxType* col, const IdxType* data, IdxType* out_idx)>;
// User-defined function for determining the number of elements to pick from one row.
// User-defined function for determining the number of elements to pick from one
// row.
//
// The column indices of the given row are stored in
// [col + off, col + off + len)
......@@ -63,8 +64,8 @@ using PickFn = std::function<void(
// \param data Pointer of the data indices.
template <typename IdxType>
using NumPicksFn = std::function<IdxType(
IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data)>;
IdxType rowid, IdxType off, IdxType len, const IdxType* col,
const IdxType* data)>;
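// Illustration (hypothetical, not part of this commit): the PickFn /
// NumPicksFn contract is easiest to see with a deterministic pair.
// num_picks_fn reports how many entries a row contributes; pick_fn then
// writes exactly that many positions from [off, off + len) into out_idx.
// The FirstK* names are made up for this sketch; such a pair could be
// passed to CSRRowWisePick below.
template <typename IdxType>
NumPicksFn<IdxType> FirstKNumPicksFn(int64_t k) {
  return [k](IdxType rowid, IdxType off, IdxType len, const IdxType* col,
             const IdxType* data) {
    // A row can contribute at most its own length.
    return std::min(static_cast<IdxType>(k), len);
  };
}
template <typename IdxType>
PickFn<IdxType> FirstKPickFn() {
  return [](IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
            const IdxType* col, const IdxType* data, IdxType* out_idx) {
    // out_idx receives positions into indices/data, not column values.
    for (IdxType j = 0; j < num_picks; ++j) out_idx[j] = off + j;
  };
}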
// User-defined function for picking elements from a range within a row.
//
......@@ -73,7 +74,8 @@ using NumPicksFn = std::function<IdxType(
//
// Similarly, the data indices are stored in
// data[off + et_idx[et_offset + i]]
// Data index pointer could be NULL, which means data[i] == off + et_idx[et_offset + i]
// Data index pointer could be NULL, which means data[i] ==
// off + et_idx[et_offset + i]
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
......@@ -93,15 +95,17 @@ using EtypeRangePickFn = std::function<void(
const IdxType* eid, IdxType* out_idx)>;
// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently.
// OpenMP parallelization on rows because each row performs computation
// independently.
template <typename IdxType>
COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
NumPicksFn<IdxType> num_picks_fn) {
COOMatrix CSRRowWisePick(
CSRMatrix mat, IdArray rows, int64_t num_picks, bool replace,
PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
using namespace aten;
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
const IdxType* data = CSRHasData(mat)? static_cast<IdxType*>(mat.data->data) : nullptr;
const IdxType* data =
CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;
const IdxType* rows_data = static_cast<IdxType*>(rows->data);
const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx;
......@@ -115,30 +119,36 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// have at least num_picks nonzero entries when replace is false.
//
// If the check holds, remove -1 elements by remove_if operation, which simply
// moves valid elements to the head of arrays and creates a view of the original
// array. The implementation consumes a little more memory than actually required.
// moves valid elements to the head of arrays and creates a view of the
// original array. The implementation consumes a little more memory than
// actually required.
//
// Otherwise, directly use the row and col arrays to construct the result COO matrix.
// Otherwise, directly use the row and col arrays to construct the result COO
// matrix.
//
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism
// is more
// significant. (minjie)
// Do not use omp_get_max_threads() since that doesn't work for compiling without OpenMP.
// Do not use omp_get_max_threads() since that doesn't work for compiling
// without OpenMP.
const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
std::vector<int64_t> global_prefix(num_threads + 1, 0);
// TODO(BarclayII) Using OMP parallel directly instead of using runtime::parallel_for
// does not handle exceptions well (directly aborts when an exception pops up).
// It runs faster though because there is less scheduling. Need to handle
// exceptions better.
// TODO(BarclayII) Using OMP parallel directly instead of using
// runtime::parallel_for does not handle exceptions well (directly aborts when
// an exception pops up). It runs faster though because there is less
// scheduling. Need to handle exceptions better.
IdArray picked_row, picked_col, picked_idx;
#pragma omp parallel num_threads(num_threads)
{
const int thread_id = omp_get_thread_num();
const int64_t start_i = thread_id * (num_rows/num_threads) +
const int64_t start_i =
thread_id * (num_rows / num_threads) +
std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
const int64_t end_i = (thread_id + 1) * (num_rows/num_threads) +
const int64_t end_i =
(thread_id + 1) * (num_rows / num_threads) +
std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
assert(thread_id + 1 < num_threads || end_i == num_rows);
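// Worked example of the partition above (not from the source): with
// num_rows = 10 and num_threads = 3, num_rows / num_threads = 3 and
// num_rows % num_threads = 1, so the first remainder thread takes one
// extra row:
//   thread 0: [0, 4)   thread 1: [4, 7)   thread 2: [7, 10)
// Adjacent ranges tile [0, num_rows) exactly, which the assert on the last
// thread double-checks.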
......@@ -149,7 +159,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
local_prefix[0] = 0;
for (int64_t i = start_i; i < end_i; ++i) {
// build prefix-sum
const int64_t local_i = i-start_i;
const int64_t local_i = i - start_i;
const IdxType rid = rows_data[i];
IdxType len = num_picks_fn(
rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
......@@ -157,8 +167,8 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
}
global_prefix[thread_id + 1] = local_prefix[num_local];
#pragma omp barrier
#pragma omp master
{
for (int t = 0; t < num_threads; ++t) {
global_prefix[t + 1] += global_prefix[t];
......@@ -168,7 +178,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
}
#pragma omp barrier
IdxType* picked_rdata = picked_row.Ptr<IdxType>();
IdxType* picked_cdata = picked_col.Ptr<IdxType>();
IdxType* picked_idata = picked_idx.Ptr<IdxType>();
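// Worked example of the two-pass allocation (not from the source): suppose
// two threads whose rows yield {2, 3} and {1, 4} picks. The threads build
// local_prefix = [0, 2, 5] and [0, 1, 5] and post their totals; the master
// folds them into global_prefix = [0, 5, 10] and allocates the outputs once
// at their final size. After the barrier, thread t writes its local row
// local_i at offset global_prefix[t] + local_prefix[local_i], so writes
// from different threads never overlap.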
......@@ -180,14 +190,15 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off;
if (len == 0)
continue;
if (len == 0) continue;
const int64_t local_i = i - start_i;
const int64_t row_offset = thread_offset + local_prefix[local_i];
const int64_t num_picks = thread_offset + local_prefix[local_i + 1] - row_offset;
const int64_t num_picks =
thread_offset + local_prefix[local_i + 1] - row_offset;
pick_fn(rid, off, len, num_picks, indices, data, picked_idata + row_offset);
pick_fn(
rid, off, len, num_picks, indices, data, picked_idata + row_offset);
for (int64_t j = 0; j < num_picks; ++j) {
const IdxType picked = picked_idata[row_offset + j];
picked_rdata[row_offset + j] = rid;
......@@ -200,25 +211,25 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
const int64_t new_len = global_prefix.back();
return COOMatrix(
mat.num_rows,
mat.num_cols,
mat.num_rows, mat.num_cols,
picked_row.CreateView({new_len}, picked_row->dtype),
picked_col.CreateView({new_len}, picked_row->dtype),
picked_idx.CreateView({new_len}, picked_row->dtype));
}
// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently.
// OpenMP parallelization on rows because each row performs computation
// independently.
template <typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_picks, bool replace,
bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,
const std::vector<NDArray>& prob_or_mask) {
COOMatrix CSRRowWisePerEtypePick(
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_picks, bool replace,
bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,
const std::vector<NDArray>& prob_or_mask) {
using namespace aten;
const IdxType* indptr = mat.indptr.Ptr<IdxType>();
const IdxType* indices = mat.indices.Ptr<IdxType>();
const IdxType* eid = CSRHasData(mat)? mat.data.Ptr<IdxType>() : nullptr;
const IdxType* eid = CSRHasData(mat) ? mat.data.Ptr<IdxType>() : nullptr;
const IdxType* rows_data = rows.Ptr<IdxType>();
const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx;
......@@ -229,8 +240,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
std::vector<IdArray> picked_idxs(rows->shape[0]);
// Check whether every edge type requests the same number of picks.
// If so, we can potentially take a fast path when a node's total number of neighbors
// is less than the requested number of picks with replace=False.
// If so, we can potentially take a fast path when a node's total number of
// neighbors is less than the requested number of picks with replace=False.
bool same_num_pick = true;
int64_t num_pick_value = num_picks[0];
for (int64_t num_pick : num_picks) {
......@@ -267,9 +278,10 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
for (int64_t j = 0; j < len; ++j) {
const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
eid2etype_offset.begin(), eid2etype_offset.end(),
homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \
const IdxType heterogenized_eid =
homogenized_eid - eid2etype_offset[heterogenized_etype];
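// Worked example of the mapping above (not from the source): with
// eid2etype_offset = {0, 5, 12} and homogenized_eid = 7, upper_bound points
// at 12 (position 2), so heterogenized_etype = 2 - 1 = 1 and
// heterogenized_eid = 7 - 5 = 2, i.e. edge 7 is the third edge of type 1.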
if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
......@@ -305,17 +317,21 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
for (int64_t j = 0; j < len; ++j) {
const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
eid2etype_offset.begin(), eid2etype_offset.end(),
homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \
const IdxType heterogenized_eid =
homogenized_eid - eid2etype_offset[heterogenized_etype];
et[j] = heterogenized_etype;
et_eid[j] = heterogenized_eid;
}
if (!rowwise_etype_sorted) // the edge types are sorted, no need to sort them
std::sort(et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];});
CHECK_LT(et[et_idx[len - 1]], num_etypes) << "etype values exceed the number of fanouts";
if (!rowwise_etype_sorted)  // the edge types are sorted, no need to sort
                            // them
std::sort(
et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) { return et[i1] < et[i2]; });
CHECK_LT(et[et_idx[len - 1]], num_etypes)
<< "etype values exceed the number of fanouts";
IdxType cur_et = et[et_idx[0]];
int64_t et_offset = 0;
......@@ -333,14 +349,18 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
// fast path, select all
for (int64_t k = 0; k < et_len; ++k) {
const IdxType eid_offset = off + et_idx[et_offset + k];
const IdxType homogenized_eid = eid ? eid[eid_offset] : eid_offset;
const IdxType homogenized_eid =
eid ? eid[eid_offset] : eid_offset;
auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \
eid2etype_offset.begin(), eid2etype_offset.end(),
homogenized_eid);
const IdxType heterogenized_etype =
it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid =
homogenized_eid - eid2etype_offset[heterogenized_etype];
if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
if (!has_probs ||
IsNullArray(prob_or_mask[heterogenized_etype])) {
// No probability, select all
rows.push_back(rid);
cols.push_back(indices[eid_offset]);
......@@ -357,32 +377,31 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
}
}
} else {
IdArray picked_idx = Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
IdArray picked_idx =
Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
IdxType* picked_idata = picked_idx.Ptr<IdxType>();
// need to call random pick
pick_fn(off, et_offset, cur_et,
et_len, et_idx, et_eid,
eid, picked_idata);
pick_fn(
off, et_offset, cur_et, et_len, et_idx, et_eid, eid,
picked_idata);
for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
const IdxType picked = picked_idata[k];
if (picked == -1)
continue;
if (picked == -1) continue;
rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+picked]]);
cols.push_back(indices[off + et_idx[et_offset + picked]]);
if (eid) {
idx.push_back(eid[off+et_idx[et_offset+picked]]);
idx.push_back(eid[off + et_idx[et_offset + picked]]);
} else {
idx.push_back(off+et_idx[et_offset+picked]);
idx.push_back(off + et_idx[et_offset + picked]);
}
}
}
if (j+1 == len)
break;
if (j + 1 == len) break;
// next etype
cur_et = et[et_idx[j+1]];
et_offset = j+1;
cur_et = et[et_idx[j + 1]];
et_offset = j + 1;
et_len = 1;
} else {
et_len++;
......@@ -402,31 +421,32 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
IdArray picked_row = Concat(picked_rows);
IdArray picked_col = Concat(picked_cols);
IdArray picked_idx = Concat(picked_idxs);
return COOMatrix(mat.num_rows, mat.num_cols,
picked_row, picked_col, picked_idx);
return COOMatrix(
mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
}
// Template for picking non-zero values row-wise. The implementation first slices
// out the corresponding rows and then converts the slice to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts the slice to CSR format.
// It then performs row-wise pick on the CSR matrix and rectifies the results.
template <typename IdxType>
COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
NumPicksFn<IdxType> num_picks_fn) {
COOMatrix COORowWisePick(
COOMatrix mat, IdArray rows, int64_t num_picks, bool replace,
PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const IdArray new_rows =
Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePick<IdxType>(
csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);
return COOMatrix(mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col,
picked.data);
return COOMatrix(
mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col, picked.data);
}
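// Hypothetical trace of the row remapping above: for rows = [7, 3],
// COOSliceRows relabels them 0 and 1, so picked.row holds values in {0, 1};
// IndexSelect(rows, picked.row) maps 0 -> 7 and 1 -> 3, restoring the
// caller's row ids, while picked.col and picked.data already refer to the
// original matrix.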
// Template for picking non-zero values row-wise. The implementation first slices
// out the corresponding rows and then converts the slice to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts the slice to CSR format.
// It then performs row-wise pick on the CSR matrix and rectifies the results.
template <typename IdxType, typename DType>
COOMatrix COORowWisePerEtypePick(
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
......@@ -435,13 +455,15 @@ COOMatrix COORowWisePerEtypePick(
const std::vector<NDArray>& prob_or_mask) {
using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const IdArray new_rows =
Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>(
csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn, prob_or_mask);
return COOMatrix(mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col,
picked.data);
csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn,
prob_or_mask);
return COOMatrix(
mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col, picked.data);
}
} // namespace impl
......
......@@ -4,7 +4,9 @@
* @brief rowwise sampling
*/
#include <dgl/random.h>
#include <numeric>
#include "./rowwise_pick.h"
namespace dgl {
......@@ -13,8 +15,8 @@ namespace impl {
namespace {
// Equivalent to numpy expression: array[idx[off:off + len]]
template <typename IdxType, typename FloatType>
inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data,
IdxType off, IdxType len) {
inline FloatArray DoubleSlice(
FloatArray array, const IdxType* idx_data, IdxType off, IdxType len) {
const FloatType* array_data = static_cast<FloatType*>(array->data);
FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);
FloatType* ret_data = static_cast<FloatType*>(ret->data);
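// Worked example (not from the source): with array = [10., 20., 30., 40.],
// idx_data = [3, 0, 2, 1], off = 1 and len = 2, the slice gathers
// array[idx_data[1]] and array[idx_data[2]], so ret = [10., 30.], matching
// numpy's array[idx[1:3]].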
......@@ -30,42 +32,44 @@ inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data,
template <typename IdxType, typename DType>
inline NumPicksFn<IdxType> GetSamplingNumPicksFn(
int64_t num_samples, NDArray prob_or_mask, bool replace) {
NumPicksFn<IdxType> num_picks_fn = [prob_or_mask, num_samples, replace]
(IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data) {
const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
IdxType nnz = 0;
for (IdxType i = off; i < off + len; ++i) {
const IdxType eid = data ? data[i] : i;
if (prob_or_mask_data[eid] > 0) {
++nnz;
}
NumPicksFn<IdxType> num_picks_fn = [prob_or_mask, num_samples, replace](
IdxType rowid, IdxType off,
IdxType len, const IdxType* col,
const IdxType* data) {
const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
const DType* prob_or_mask_data = prob_or_mask.Ptr<DType>();
IdxType nnz = 0;
for (IdxType i = off; i < off + len; ++i) {
const IdxType eid = data ? data[i] : i;
if (prob_or_mask_data[eid] > 0) {
++nnz;
}
}
if (replace) {
return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
} else {
return std::min(static_cast<IdxType>(max_num_picks), nnz);
}
};
if (replace) {
return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
} else {
return std::min(static_cast<IdxType>(max_num_picks), nnz);
}
};
return num_picks_fn;
}
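// Worked example of the count above (not from the source): a row with
// len = 5 and per-edge mask [1, 0, 1, 0, 0] has nnz = 2. With
// num_samples = 3, replace = false yields min(3, 2) = 2 picks, while
// replace = true yields the full 3 picks (and 0 if every entry were masked).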
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetSamplingPickFn(
int64_t num_samples, NDArray prob_or_mask, bool replace) {
PickFn<IdxType> pick_fn = [prob_or_mask, num_samples, replace]
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
NDArray prob_or_mask_selected = DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);
RandomEngine::ThreadLocal()->Choice<IdxType, DType>(
num_picks, prob_or_mask_selected, out_idx, replace);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] += off;
}
};
PickFn<IdxType> pick_fn = [prob_or_mask, num_samples, replace](
IdxType rowid, IdxType off, IdxType len,
IdxType num_picks, const IdxType* col,
const IdxType* data, IdxType* out_idx) {
NDArray prob_or_mask_selected =
DoubleSlice<IdxType, DType>(prob_or_mask, data, off, len);
RandomEngine::ThreadLocal()->Choice<IdxType, DType>(
num_picks, prob_or_mask_selected, out_idx, replace);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] += off;
}
};
return pick_fn;
}
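// Hypothetical trace of the pick above: Choice draws local positions in
// [0, len) against the sliced probabilities; with off = 10, len = 4 and
// draws {2, 0}, the shift by off turns out_idx into {12, 10}, i.e. absolute
// positions into the CSR indices/data arrays.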
......@@ -73,108 +77,112 @@ template <typename IdxType, typename FloatType>
inline EtypeRangePickFn<IdxType> GetSamplingRangePickFn(
const std::vector<int64_t>& num_samples,
const std::vector<FloatArray>& prob, bool replace) {
EtypeRangePickFn<IdxType> pick_fn = [prob, num_samples, replace]
(IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx,
const std::vector<IdxType> &et_eid,
const IdxType* eid, IdxType* out_idx) {
const FloatArray& p = prob[cur_et];
const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr<FloatType>();
FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);
FloatType* probs_data = probs.Ptr<FloatType>();
for (int64_t j = 0; j < et_len; ++j) {
const IdxType cur_eid = et_eid[et_idx[et_offset + j]];
probs_data[j] = p_data ? p_data[cur_eid] : static_cast<FloatType>(1.);
}
EtypeRangePickFn<IdxType> pick_fn =
[prob, num_samples, replace](
IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType>& et_idx,
const std::vector<IdxType>& et_eid, const IdxType* eid,
IdxType* out_idx) {
const FloatArray& p = prob[cur_et];
const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr<FloatType>();
FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);
FloatType* probs_data = probs.Ptr<FloatType>();
for (int64_t j = 0; j < et_len; ++j) {
const IdxType cur_eid = et_eid[et_idx[et_offset + j]];
probs_data[j] = p_data ? p_data[cur_eid] : static_cast<FloatType>(1.);
}
RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
num_samples[cur_et], probs, out_idx, replace);
};
RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
num_samples[cur_et], probs, out_idx, replace);
};
return pick_fn;
}
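// Worked example (not from the source): for the current etype segment, the
// loop above gathers one probability per candidate edge; if prob[cur_et] is
// a null array, p_data is nullptr and every entry gets 1.0, so the Choice
// call degenerates to uniform sampling over the et_len candidates.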
template <typename IdxType>
inline NumPicksFn<IdxType> GetSamplingUniformNumPicksFn(
int64_t num_samples, bool replace) {
NumPicksFn<IdxType> num_picks_fn = [num_samples, replace]
(IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data) {
const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
if (replace) {
return static_cast<IdxType>(len == 0 ? 0 : max_num_picks);
} else {
return std::min(static_cast<IdxType>(max_num_picks), len);
}
};
NumPicksFn<IdxType> num_picks_fn = [num_samples, replace](
IdxType rowid, IdxType off,
IdxType len, const IdxType* col,
const IdxType* data) {
const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
if (replace) {
return static_cast<IdxType>(len == 0 ? 0 : max_num_picks);
} else {
return std::min(static_cast<IdxType>(max_num_picks), len);
}
};
return num_picks_fn;
}
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
int64_t num_samples, bool replace) {
PickFn<IdxType> pick_fn = [num_samples, replace]
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
num_picks, len, out_idx, replace);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] += off;
}
};
PickFn<IdxType> pick_fn = [num_samples, replace](
IdxType rowid, IdxType off, IdxType len,
IdxType num_picks, const IdxType* col,
const IdxType* data, IdxType* out_idx) {
RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
num_picks, len, out_idx, replace);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] += off;
}
};
return pick_fn;
}
template <typename IdxType>
inline EtypeRangePickFn<IdxType> GetSamplingUniformRangePickFn(
const std::vector<int64_t>& num_samples, bool replace) {
EtypeRangePickFn<IdxType> pick_fn = [num_samples, replace]
(IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType> &et_idx,
const std::vector<IdxType> &et_eid,
const IdxType* data, IdxType* out_idx) {
RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
num_samples[cur_et], et_len, out_idx, replace);
};
EtypeRangePickFn<IdxType> pick_fn =
[num_samples, replace](
IdxType off, IdxType et_offset, IdxType cur_et, IdxType et_len,
const std::vector<IdxType>& et_idx,
const std::vector<IdxType>& et_eid, const IdxType* data,
IdxType* out_idx) {
RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
num_samples[cur_et], et_len, out_idx, replace);
};
return pick_fn;
}
template <typename IdxType, typename FloatType>
inline NumPicksFn<IdxType> GetSamplingBiasedNumPicksFn(
int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
NumPicksFn<IdxType> num_picks_fn = [num_samples, split, bias, replace]
(IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data) {
const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
const int64_t num_tags = split->shape[1] - 1;
const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
const FloatType* bias_data = bias.Ptr<FloatType>();
IdxType nnz = 0;
for (int64_t j = 0; j < num_tags; ++j) {
if (bias_data[j] > 0) {
nnz += tag_offset[j + 1] - tag_offset[j];
}
NumPicksFn<IdxType> num_picks_fn = [num_samples, split, bias, replace](
IdxType rowid, IdxType off,
IdxType len, const IdxType* col,
const IdxType* data) {
const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
const int64_t num_tags = split->shape[1] - 1;
const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
const FloatType* bias_data = bias.Ptr<FloatType>();
IdxType nnz = 0;
for (int64_t j = 0; j < num_tags; ++j) {
if (bias_data[j] > 0) {
nnz += tag_offset[j + 1] - tag_offset[j];
}
}
if (replace) {
return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
} else {
return std::min(static_cast<IdxType>(max_num_picks), nnz);
}
};
if (replace) {
return static_cast<IdxType>(nnz == 0 ? 0 : max_num_picks);
} else {
return std::min(static_cast<IdxType>(max_num_picks), nnz);
}
};
return num_picks_fn;
}
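// Worked example of the tag layout above (not from the source): split
// stores num_tags + 1 offsets per row. With tag_offset = [0, 3, 5] (tag 0
// owns 3 edges, tag 1 owns 2) and bias = [0., 1.], only tag 1 contributes,
// so nnz = 2; num_samples = 4 with replace = false then yields
// min(4, 2) = 2 picks.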
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingBiasedPickFn(
int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
PickFn<IdxType> pick_fn = [num_samples, split, bias, replace]
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
const IdxType *tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
PickFn<IdxType> pick_fn = [num_samples, split, bias, replace](
IdxType rowid, IdxType off, IdxType len,
IdxType num_picks, const IdxType* col,
const IdxType* data, IdxType* out_idx) {
const IdxType* tag_offset = split.Ptr<IdxType>() + rowid * split->shape[1];
RandomEngine::ThreadLocal()->BiasedChoice<IdxType, FloatType>(
num_picks, tag_offset, bias, out_idx, replace);
num_picks, tag_offset, bias, out_idx, replace);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] += off;
}
......@@ -187,15 +195,16 @@ inline PickFn<IdxType> GetSamplingBiasedPickFn(
/////////////////////////////// CSR ///////////////////////////////
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
NDArray prob_or_mask, bool replace) {
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
CHECK(prob_or_mask.defined());
auto num_picks_fn = GetSamplingNumPicksFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
auto pick_fn = GetSamplingPickFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
auto num_picks_fn =
GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
auto pick_fn =
GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
......@@ -219,49 +228,52 @@ template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<NDArray>& prob_or_mask,
bool replace, bool rowwise_etype_sorted) {
CHECK(prob_or_mask.size() == num_samples.size()) <<
"the number of probability tensors does not match the number of edge types.";
for (auto& p : prob_or_mask)
CHECK(p.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace,
bool rowwise_etype_sorted) {
CHECK(prob_or_mask.size() == num_samples.size())
<< "the number of probability tensors does not match the number of edge "
"types.";
for (auto& p : prob_or_mask) CHECK(p.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
return CSRRowWisePerEtypePick<IdxType, DType>(
mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted, pick_fn,
prob_or_mask);
mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
pick_fn, prob_or_mask);
}
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool, bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool, bool);
template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
COOMatrix CSRRowWiseSamplingUniform(
CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
auto num_picks_fn = GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
auto num_picks_fn =
GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
......@@ -274,28 +286,25 @@ template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace, bool rowwise_etype_sorted) {
const std::vector<int64_t>& num_samples, bool replace,
bool rowwise_etype_sorted) {
auto pick_fn = GetSamplingUniformRangePickFn<IdxType>(num_samples, replace);
return CSRRowWisePerEtypePick<IdxType, float>(
mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted, pick_fn, {});
mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
pick_fn, {});
}
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool,
bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool,
bool);
CSRMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, bool, bool);
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat,
IdArray rows,
int64_t num_samples,
NDArray tag_offset,
FloatArray bias,
bool replace
) {
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
FloatArray bias, bool replace) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
auto num_picks_fn = GetSamplingBiasedNumPicksFn<IdxType, FloatType>(
......@@ -306,30 +315,30 @@ COOMatrix CSRRowWiseSamplingBiased(
}
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
/////////////////////////////// COO ///////////////////////////////
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
NDArray prob_or_mask, bool replace) {
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
CHECK(prob_or_mask.defined());
auto num_picks_fn = GetSamplingNumPicksFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
auto pick_fn = GetSamplingPickFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
auto num_picks_fn =
GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
auto pick_fn =
GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
......@@ -353,48 +362,50 @@ template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, uint8_t>(
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, const std::vector<NDArray>& prob_or_mask,
bool replace) {
CHECK(prob_or_mask.size() == num_samples.size()) <<
"the number of probability tensors do not match the number of edge types.";
for (auto& p : prob_or_mask)
CHECK(p.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace) {
CHECK(prob_or_mask.size() == num_samples.size())
<< "the number of probability tensors do not match the number of edge "
"types.";
for (auto& p : prob_or_mask) CHECK(p.defined());
auto pick_fn = GetSamplingRangePickFn<IdxType, DType>(
num_samples, prob_or_mask, replace);
return COORowWisePerEtypePick<IdxType, DType>(
mat, rows, eid2etype_offset, num_samples, replace, pick_fn, prob_or_mask);
}
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, int8_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, int8_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, uint8_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, uint8_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&,
const std::vector<NDArray>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, const std::vector<NDArray>&, bool);
template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
COOMatrix COORowWiseSamplingUniform(
COOMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
// If num_samples is -1, select all neighbors without replacement.
replace = (replace && num_samples != -1);
auto num_picks_fn = GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
auto num_picks_fn =
GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}
......@@ -414,9 +425,11 @@ COOMatrix COORowWisePerEtypeSamplingUniform(
}
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
COOMatrix, IdArray, const std::vector<int64_t>&, const std::vector<int64_t>&, bool);
COOMatrix, IdArray, const std::vector<int64_t>&,
const std::vector<int64_t>&, bool);
} // namespace impl
} // namespace aten
......
......@@ -3,8 +3,9 @@
* @file array/cpu/rowwise_topk.cc
* @brief rowwise topk
*/
#include <numeric>
#include <algorithm>
#include <numeric>
#include "./rowwise_pick.h"
namespace dgl {
......@@ -14,52 +15,52 @@ namespace {
template <typename IdxType>
inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
NumPicksFn<IdxType> num_picks_fn = [k]
(IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data) {
const int64_t max_num_picks = (k == -1) ? len : k;
return std::min(static_cast<IdxType>(max_num_picks), len);
};
NumPicksFn<IdxType> num_picks_fn = [k](IdxType rowid, IdxType off,
IdxType len, const IdxType* col,
const IdxType* data) {
const int64_t max_num_picks = (k == -1) ? len : k;
return std::min(static_cast<IdxType>(max_num_picks), len);
};
return num_picks_fn;
}
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) {
const DType* wdata = static_cast<DType*>(weight->data);
PickFn<IdxType> pick_fn = [ascending, wdata]
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
std::function<bool(IdxType, IdxType)> compare_fn;
if (ascending) {
if (data) {
compare_fn = [wdata, data] (IdxType i, IdxType j) {
return wdata[data[i]] < wdata[data[j]];
};
} else {
compare_fn = [wdata] (IdxType i, IdxType j) {
return wdata[i] < wdata[j];
};
}
PickFn<IdxType> pick_fn = [ascending, wdata](
IdxType rowid, IdxType off, IdxType len,
IdxType num_picks, const IdxType* col,
const IdxType* data, IdxType* out_idx) {
std::function<bool(IdxType, IdxType)> compare_fn;
if (ascending) {
if (data) {
compare_fn = [wdata, data](IdxType i, IdxType j) {
return wdata[data[i]] < wdata[data[j]];
};
} else {
if (data) {
compare_fn = [wdata, data] (IdxType i, IdxType j) {
return wdata[data[i]] > wdata[data[j]];
};
} else {
compare_fn = [wdata] (IdxType i, IdxType j) {
return wdata[i] > wdata[j];
};
}
compare_fn = [wdata](IdxType i, IdxType j) {
return wdata[i] < wdata[j];
};
}
std::vector<IdxType> idx(len);
std::iota(idx.begin(), idx.end(), off);
std::sort(idx.begin(), idx.end(), compare_fn);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] = idx[j];
} else {
if (data) {
compare_fn = [wdata, data](IdxType i, IdxType j) {
return wdata[data[i]] > wdata[data[j]];
};
} else {
compare_fn = [wdata](IdxType i, IdxType j) {
return wdata[i] > wdata[j];
};
}
};
}
std::vector<IdxType> idx(len);
std::iota(idx.begin(), idx.end(), off);
std::sort(idx.begin(), idx.end(), compare_fn);
for (int64_t j = 0; j < num_picks; ++j) {
out_idx[j] = idx[j];
}
};
return pick_fn;
}
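// Worked example of the pick above (not from the source): a row with
// off = 0, len = 3 and weights [0.9, 0.1, 0.5] (data == NULL, so wdata is
// indexed directly) sorts idx to [1, 2, 0] in ascending mode; num_picks = 2
// then emits positions {1, 2}, the two smallest weights.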
......
......@@ -4,73 +4,65 @@
* @brief SDDMM C APIs and definitions.
*/
#include "./sddmm.h"
#include <dgl/array.h>
namespace dgl {
namespace aten {
#define SWITCH_RHS(rhs_target, RhsTarget, ...) \
do { \
if ((rhs_target) == 0) { \
constexpr int RhsTarget = 0; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 1) { \
constexpr int RhsTarget = 1; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 2) { \
constexpr int RhsTarget = 2; \
{ __VA_ARGS__ } \
} else { \
LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
} \
#define SWITCH_RHS(rhs_target, RhsTarget, ...) \
do { \
if ((rhs_target) == 0) { \
constexpr int RhsTarget = 0; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 1) { \
constexpr int RhsTarget = 1; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 2) { \
constexpr int RhsTarget = 2; \
{ __VA_ARGS__ } \
} else { \
LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
} \
} while (0)
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
do { \
if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 1) { \
constexpr int LhsTarget = 1; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 2) { \
constexpr int LhsTarget = 2; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else { \
LOG(INFO) << "Invalid lhs target: " << (lhs_target); \
} \
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
do { \
if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 1) { \
constexpr int LhsTarget = 1; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 2) { \
constexpr int LhsTarget = 2; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else { \
LOG(INFO) << "Invalid lhs target: " << (lhs_target); \
} \
} while (0)
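// Illustration (hypothetical, not part of DGL): the macros re-materialize a
// runtime target id as a constexpr int so it can be used as a template
// argument. DemoKernel/DemoDispatch below are made-up names for this sketch.
template <int LhsTarget, int RhsTarget>
void DemoKernel() {
  LOG(INFO) << "lhs target = " << LhsTarget << ", rhs target = " << RhsTarget;
}
void DemoDispatch(int lhs_target, int rhs_target) {
  // Each branch of the expanded if-chain instantiates DemoKernel with a
  // different pair of compile-time constants.
  SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
    DemoKernel<LhsTarget, RhsTarget>();
  });
}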
/** @brief Generalized SDDMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
void SDDMMCsr(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
});
});
}
/** @brief Generalized SDDMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
void SDDMMCsrHetero(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM for each relation type */
......@@ -79,7 +71,8 @@ void SDDMMCsrHetero(const std::string& op,
NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray out = vec_out[etype];
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
}
});
});
......@@ -87,79 +80,63 @@ void SDDMMCsrHetero(const std::string& op,
template void SDDMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
/** @brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
void SDDMMCoo(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
});
});
}
/** @brief Generalized SDDMM on Coo format with Heterograph support. */
template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
void SDDMMCooHetero(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM for each relation type */
......@@ -168,7 +145,8 @@ void SDDMMCooHetero(const std::string& op,
NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray out = vec_out[etype];
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
}
});
});
......@@ -176,48 +154,40 @@ void SDDMMCooHetero(const std::string& op,
template void SDDMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
} // namespace aten
......
......@@ -4,8 +4,11 @@
* @brief Segment reduce C APIs and definitions.
*/
#include "./segment_reduce.h"
#include <dgl/array.h>
#include <string>
#include "./spmm_binary_ops.h"
namespace dgl {
......@@ -14,10 +17,7 @@ namespace aten {
/** @brief Segment Reduce operator. */
template <int XPU, typename IdType, typename DType>
void SegmentReduce(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg) {
if (op == "sum") {
cpu::SegmentSum<IdType, DType>(feat, offsets, out);
......@@ -36,73 +36,47 @@ void SegmentReduce(
/** @brief Scatter Add.*/
template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
cpu::ScatterAdd<IdType, DType>(feat, idx, out);
}
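// Illustrative aside: ScatterAdd accumulates feat entries into out at the
// positions given by idx, i.e. out[idx[i]] += feat[i]; unlike a plain
// scatter, repeated indices accumulate instead of overwriting. A standalone
// sketch with hypothetical names and scalar features (not DGL's API):
#include <cstdint>
#include <iostream>
#include <vector>

void ScatterAddSketch(
    const std::vector<double>& feat, const std::vector<int64_t>& idx,
    std::vector<double>* out) {
  for (size_t i = 0; i < feat.size(); ++i) (*out)[idx[i]] += feat[i];
}

int main() {
  std::vector<double> feat = {1.0, 2.0, 3.0};
  std::vector<int64_t> idx = {0, 1, 0};  // elements 0 and 2 both hit slot 0
  std::vector<double> out(2, 0.0);
  ScatterAddSketch(feat, idx, &out);
  std::cout << out[0] << ' ' << out[1] << '\n';  // 4 2
  return 0;
}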
/** @brief Update gradients for reduce operator max/min on heterogeneous
* graph.*/
template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
}
/** @brief Backward function of segment cmp.*/
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}
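// Illustrative aside: SegmentReduce reduces contiguous runs of feat, with
// segment i spanning [offsets[i], offsets[i+1]) in the usual CSR style; for
// max/min the arg array records the winning position so BackwardSegmentCmp
// can route gradients back to it. A standalone sketch of the sum case, with
// hypothetical names and scalar features (not DGL's API):
#include <cstdint>
#include <iostream>
#include <vector>

void SegmentSumSketch(
    const std::vector<double>& feat, const std::vector<int64_t>& offsets,
    std::vector<double>* out) {
  out->assign(offsets.size() - 1, 0.0);
  for (size_t i = 0; i + 1 < offsets.size(); ++i)
    for (int64_t j = offsets[i]; j < offsets[i + 1]; ++j) (*out)[i] += feat[j];
}

int main() {
  std::vector<double> feat = {1, 2, 3, 4, 5};
  std::vector<int64_t> offsets = {0, 2, 5};  // segments [0,2) and [2,5)
  std::vector<double> out;
  SegmentSumSketch(feat, offsets, &out);
  std::cout << out[0] << ' ' << out[1] << '\n';  // 3 12
  return 0;
}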
template void SegmentReduce<kDGLCPU, int32_t, float>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, float>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, double>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, double>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void ScatterAdd<kDGLCPU, int32_t, float>(
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, float>(
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, double>(
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, double>(
NDArray feat, NDArray arg, NDArray out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op,
......@@ -122,21 +96,13 @@ template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, double>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, double>(
NDArray feat, NDArray arg, NDArray out);
} // namespace aten
} // namespace dgl
......@@ -4,6 +4,7 @@
* @brief SPMM C APIs and definitions.
*/
#include "./spmm.h"
#include <dgl/array.h>
namespace dgl {
......@@ -11,13 +12,10 @@ namespace aten {
/** @brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
const int64_t dim = bcast.out_len;
if (reduce == "sum") {
SWITCH_OP(op, Op, {
......@@ -25,13 +23,15 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
});
} else if (reduce == "max" || reduce == "min") {
SWITCH_OP(op, Op, {
DType* out_off = out.Ptr<DType>();
if (reduce == "max") {
std::fill(
out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
} else {
std::fill(
out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
}
......@@ -43,15 +43,14 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
/** @brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids) {
const int64_t dim = bcast.out_len;
if (reduce == "sum") {
SWITCH_OP(op, Op, {
......@@ -60,8 +59,10 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype];
NDArray ufeat =
(vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat =
(vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = (*vec_out)[dst_id];
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
}
......@@ -71,21 +72,27 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
std::vector<bool> updated((*vec_out).size(), false);
// TODO(Israt): use vector updated to fill(out...) too
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
DType* out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
if (reduce == "max")
std::fill(
out_off, out_off + vec_csr[etype].num_rows * dim,
cpu::op::Max<DType>::zero);
else
std::fill(
out_off, out_off + vec_csr[etype].num_rows * dim,
cpu::op::Min<DType>::zero);
const dgl_type_t dst_id = out_node_tids[etype];
if (!updated[dst_id]) {
updated[dst_id] = true;
if (Op::use_lhs) {
IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
std::fill(
argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
}
if (Op::use_rhs) {
IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
std::fill(
arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
}
}
}
......@@ -94,17 +101,21 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype];
NDArray ufeat =
(vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat =
(vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = (*vec_out)[dst_id];
if (reduce == "max") {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
(*out_aux)[3][dst_id], src_id, etype);
} else {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
(*out_aux)[3][dst_id], src_id, etype);
}
}
});
......@@ -114,120 +125,105 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
}
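// Illustrative aside: with op = mul and reduce = sum, SpMM on CSR computes
// out[i] = sum over row i's nonzeros j of ufeat[indices[j]] * efeat[j]; the
// max/min paths instead pre-fill out with the reduction's identity (the
// op::Max/Min::zero fills above) and track arg arrays for the backward pass.
// A standalone scalar-feature sketch with hypothetical names (not DGL's API):
#include <cstdint>
#include <iostream>
#include <vector>

void SpmmCsrSumSketch(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    const std::vector<double>& ufeat, const std::vector<double>& efeat,
    std::vector<double>* out) {
  out->assign(indptr.size() - 1, 0.0);  // sum's identity element
  for (size_t i = 0; i + 1 < indptr.size(); ++i)
    for (int64_t j = indptr[i]; j < indptr[i + 1]; ++j)
      (*out)[i] += ufeat[indices[j]] * efeat[j];
}

int main() {
  // Row 0 has nonzeros at columns 0 and 1; row 1 has one nonzero at column 1.
  std::vector<int64_t> indptr = {0, 2, 3}, indices = {0, 1, 1};
  std::vector<double> ufeat = {1.0, 2.0}, efeat = {10.0, 10.0, 10.0};
  std::vector<double> out;
  SpmmCsrSumSketch(indptr, indices, ufeat, efeat, &out);
  std::cout << out[0] << ' ' << out[1] << '\n';  // 30 20
  return 0;
}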
template void SpMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
/** @brief Edge_softmax_csr forward op on Csr format. */
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out) {
SWITCH_OP(op, Op, {
cpu::Edge_softmax_csr_forward<IdType, DType, Op>(
bcast, csr, ufeat, efeat, out);
});
}
/** @brief Edge_softmax_csr backward op on Csr format. */
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray out, NDArray sds, NDArray back_out) {
SWITCH_OP(op, Op, {
cpu::Edge_softmax_csr_backward<IdType, DType, Op>(
bcast, csr, out, sds, back_out);
});
}
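// Illustrative aside: edge softmax normalizes the scores of all edges that
// share a CSR row, so each row's scores sum to one after the forward pass;
// the backward op combines the saved outputs with the incoming gradients.
// A standalone forward sketch using the max-shift trick for numerical
// stability, with hypothetical names and scalar scores (not DGL's API):
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

void EdgeSoftmaxSketch(
    const std::vector<int64_t>& indptr, std::vector<double>* escore) {
  for (size_t i = 0; i + 1 < indptr.size(); ++i) {
    double m = -INFINITY, s = 0.0;
    for (int64_t j = indptr[i]; j < indptr[i + 1]; ++j)
      m = std::max(m, (*escore)[j]);
    for (int64_t j = indptr[i]; j < indptr[i + 1]; ++j)
      s += ((*escore)[j] = std::exp((*escore)[j] - m));
    for (int64_t j = indptr[i]; j < indptr[i + 1]; ++j) (*escore)[j] /= s;
  }
}

int main() {
  std::vector<int64_t> indptr = {0, 2, 3};       // row 0: 2 edges, row 1: 1
  std::vector<double> escore = {0.0, 0.0, 5.0};  // raw edge scores
  EdgeSoftmaxSketch(indptr, &escore);
  std::cout << escore[0] << ' ' << escore[1] << ' ' << escore[2] << '\n';
  // 0.5 0.5 1
  return 0;
}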
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
/** @brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, typename DType>
void SpMMCoo(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_OP(op, Op, {
cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
......@@ -247,21 +243,21 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
}
template void SpMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
} // namespace aten
} // namespace dgl
......@@ -7,6 +7,7 @@
#define DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <limits>
namespace dgl {
namespace aten {
......
......@@ -63,8 +63,10 @@ template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#if BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int32_t>(
NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(
NDArray, IdArray);
#endif // BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
......@@ -76,8 +78,8 @@ DType IndexSelect(NDArray array, int64_t index) {
auto device = runtime::DeviceAPI::Get(array->ctx);
DType ret = static_cast<DType>(0.0f);
device->CopyDataFromTo(
static_cast<DType*>(array->data) + index, 0, &ret, 0, sizeof(DType),
array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
return ret;
}
......@@ -87,7 +89,8 @@ template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
#if BF16_ENABLED
template __nv_bfloat16 IndexSelect<kDGLCUDA, __nv_bfloat16>(
NDArray array, int64_t index);
#endif // BF16_ENABLED
template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
......
......@@ -4,10 +4,11 @@
* @brief Array operator GPU implementation
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/cuda/cuda_hashtable.cuh"
#include "../arith.h"
#include "./utils.h"
namespace dgl {
using runtime::NDArray;
......@@ -38,35 +39,56 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(len);
int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(
(_BinaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs_data, rhs_data,
ret_data, len);
return ret;
}
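// Illustrative aside: the launch configuration above is plain ceil division,
// nt threads per block and nb = (len + nt - 1) / nt blocks, and each thread
// then walks a grid-stride loop (tx, tx + nb*nt, ...) so every element of an
// arbitrary-length array is covered. A host-side sketch of the same indexing,
// with hypothetical names and no GPU required:
#include <cstdint>
#include <iostream>
#include <vector>

void GridStrideAddSketch(
    int nb, int nt, const std::vector<int>& lhs, const std::vector<int>& rhs,
    std::vector<int>* ret) {
  const int64_t len = static_cast<int64_t>(lhs.size());
  const int64_t stride = static_cast<int64_t>(nb) * nt;  // total thread count
  for (int64_t tx = 0; tx < stride; ++tx)                // "each thread"
    for (int64_t i = tx; i < len; i += stride) (*ret)[i] = lhs[i] + rhs[i];
}

int main() {
  const int nt = 4;  // stand-in for cuda::FindNumThreads
  std::vector<int> a = {1, 2, 3, 4, 5}, b = {10, 20, 30, 40, 50}, r(5);
  const int nb = (static_cast<int>(a.size()) + nt - 1) / nt;  // ceil division
  GridStrideAddSketch(nb, nt, a, b, &r);
  for (int v : r) std::cout << v << ' ';  // 11 22 33 44 55
  std::cout << '\n';
  return 0;
}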
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(
IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(
IdArray lhs, IdArray rhs);
template <typename IdType, typename Op>
__global__ void _BinaryElewiseKernel(
......@@ -88,36 +110,56 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(len);
int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(
(_BinaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs_data, rhs,
ret_data, len);
return ret;
}
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(
IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(
IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(
IdArray lhs, int64_t rhs);
template <typename IdType, typename Op>
__global__ void _BinaryElewiseKernel(
......@@ -139,34 +181,56 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(len);
int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(
(_BinaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs, rhs_data,
ret_data, len);
return ret;
}
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(
int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(
int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(
int64_t lhs, IdArray rhs);
template <typename IdType, typename Op>
__global__ void _UnaryElewiseKernel(
......@@ -188,9 +252,9 @@ IdArray UnaryElewise(IdArray lhs) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(len);
int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(
(_UnaryElewiseKernel<IdType, Op>), nb, nt, 0, stream, lhs_data, ret_data,
len);
return ret;
}
......@@ -200,8 +264,7 @@ template IdArray UnaryElewise<kDGLCUDA, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full /////////////////////////////
template <typename DType>
__global__ void _FullKernel(DType* out, int64_t length, DType val) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
......@@ -217,20 +280,25 @@ NDArray Full(DType val, int64_t length, DGLContext ctx) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length);
int nb = (length + nt - 1) / nt;
CUDA_KERNEL_CALL(
(_FullKernel<DType>), nb, nt, 0, stream, ret_data, length, val);
return ret;
}
template IdArray Full<kDGLCUDA, int32_t>(
int32_t val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, int64_t>(
int64_t val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, __half>(
__half val, int64_t length, DGLContext ctx);
#if BF16_ENABLED
template IdArray Full<kDGLCUDA, __nv_bfloat16>(
__nv_bfloat16 val, int64_t length, DGLContext ctx);
#endif // BF16_ENABLED
template IdArray Full<kDGLCUDA, float>(
float val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, double>(
double val, int64_t length, DGLContext ctx);
///////////////////////////// Range /////////////////////////////
......@@ -249,15 +317,13 @@ IdArray Range(IdType low, IdType high, DGLContext ctx) {
CHECK(high >= low) << "high must be bigger than low";
const IdType length = high - low;
IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
if (length == 0) return ret;
IdType* ret_data = static_cast<IdType*>(ret->data);
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length);
int nb = (length + nt - 1) / nt;
CUDA_KERNEL_CALL(
(_RangeKernel<IdType>), nb, nt, 0, stream, ret_data, low, length);
return ret;
}
......@@ -269,7 +335,6 @@ template IdArray Range<kDGLCUDA, int64_t>(int64_t, int64_t, DGLContext);
template <typename IdType>
__global__ void _RelabelKernel(
IdType* out, int64_t length, DeviceOrderedHashTable<IdType> table) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
......@@ -295,30 +360,20 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
// build node maps and get the induced nodes
OrderedHashTable<IdType> node_map(total_length, ctx, stream);
int64_t num_induced = 0;
int64_t* num_induced_device =
static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
IdArray induced_nodes = NewIdArray(total_length, ctx, sizeof(IdType) * 8);
CUDA_CALL(cudaMemsetAsync(
num_induced_device, 0, sizeof(*num_induced_device), stream));
node_map.FillWithDuplicates(
all_nodes.Ptr<IdType>(), all_nodes->shape[0], induced_nodes.Ptr<IdType>(),
num_induced_device, stream);
// copy using the internal current stream
device->CopyDataFromTo(
num_induced_device, 0, &num_induced, 0, sizeof(num_induced), ctx,
DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});
device->StreamSync(ctx, stream);
device->FreeWorkspace(ctx, num_induced_device);
......@@ -331,16 +386,18 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
for (IdArray arr : arrays) {
const int64_t length = arr->shape[0];
int nb = (length + nt - 1) / nt;
CUDA_KERNEL_CALL(
(_RelabelKernel<IdType>), nb, nt, 0, stream, arr.Ptr<IdType>(), length,
node_map.DeviceHandle());
}
return induced_nodes;
}
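// Illustrative aside: Relabel_ compacts an arbitrary set of node IDs into
// the consecutive range [0, num_unique), rewriting the input arrays in place
// and returning the induced original IDs; the OrderedHashTable plays the
// role an ordinary hash map would play on the CPU. A standalone host-side
// sketch with hypothetical names (not DGL's API):
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

std::vector<int64_t> RelabelSketch(std::vector<std::vector<int64_t>*> arrays) {
  std::unordered_map<int64_t, int64_t> node_map;  // old ID -> new ID
  std::vector<int64_t> induced;                   // new ID -> old ID
  for (auto* arr : arrays)
    for (int64_t& id : *arr) {
      auto it = node_map.find(id);
      if (it == node_map.end()) {
        it = node_map.emplace(id, static_cast<int64_t>(induced.size())).first;
        induced.push_back(id);
      }
      id = it->second;  // relabel in place
    }
  return induced;
}

int main() {
  std::vector<int64_t> src = {10, 30, 10}, dst = {30, 42, 42};
  const auto induced = RelabelSketch({&src, &dst});
  for (int64_t v : src) std::cout << v << ' ';      // 0 1 0
  for (int64_t v : dst) std::cout << v << ' ';      // 1 2 2
  for (int64_t v : induced) std::cout << v << ' ';  // 10 30 42
  std::cout << '\n';
  return 0;
}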
template IdArray Relabel_<kDGLCUDA, int32_t>(
const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCUDA, int64_t>(
const std::vector<IdArray>& arrays);
///////////////////////////// AsNumBits /////////////////////////////
......@@ -363,18 +420,19 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
int nt = cuda::FindNumThreads(length);
int nb = (length + nt - 1) / nt;
if (bits == 32) {
CUDA_KERNEL_CALL(
(_CastKernel<IdType, int32_t>), nb, nt, 0, stream,
static_cast<IdType*>(arr->data), static_cast<int32_t*>(ret->data),
length);
} else {
CUDA_KERNEL_CALL(
(_CastKernel<IdType, int64_t>), nb, nt, 0, stream,
static_cast<IdType*>(arr->data), static_cast<int64_t*>(ret->data),
length);
}
return ret;
}
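// Illustrative aside: AsNumBits only changes the storage width of an ID
// array; _CastKernel above is an elementwise static_cast. A host-side
// equivalent with hypothetical names (the caller guarantees the IDs fit):
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> AsNumBits32Sketch(const std::vector<int64_t>& arr) {
  std::vector<int32_t> ret(arr.size());
  for (size_t i = 0; i < arr.size(); ++i)
    ret[i] = static_cast<int32_t>(arr[i]);  // narrows 64-bit IDs to 32-bit
  return ret;
}

int main() {
  std::vector<int64_t> ids = {0, 1, 123456};
  for (int32_t v : AsNumBits32Sketch(ids)) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}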
template IdArray AsNumBits<kDGLCUDA, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCUDA, int64_t>(IdArray arr, uint8_t bits);
......
......@@ -4,6 +4,7 @@
* @brief Array scatter GPU implementation
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
......@@ -13,8 +14,8 @@ namespace aten {
namespace impl {
template <typename DType, typename IdType>
__global__ void _ScatterKernel(
const IdType* index, const DType* value, int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
......@@ -33,15 +34,15 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream, idx, val, len, outd);
}
template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
#if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int32_t>(
IdArray, NDArray, NDArray);
#endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
......@@ -49,7 +50,8 @@ template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
#if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int64_t>(
IdArray, NDArray, NDArray);
#endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
......