Unverified Commit 8ac27dad authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4824)



* [Misc] clang-format auto fix.

* blabla

* ablabla

* blabla
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bcd37684
......@@ -10,10 +10,11 @@
#define DGL_ATEN_ARRAY_OPS_H_
#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <tuple>
#include <string>
#include "./types.h"
namespace dgl {
......@@ -24,7 +25,8 @@ namespace aten {
//////////////////////////////////////////////////////////////////////
/** @return A special array to represent null. */
inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
inline NDArray NullArray(
const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
return NDArray::Empty({0}, dtype, ctx);
}
......@@ -32,9 +34,7 @@ inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
/**
* @return Whether the input array is a null array.
*/
inline bool IsNullArray(NDArray array) {
return array->shape[0] == 0;
}
inline bool IsNullArray(NDArray array) { return array->shape[0] == 0; }
/**
* @brief Create a new id array with given length
......@@ -43,8 +43,8 @@ inline bool IsNullArray(NDArray array) {
* @param nbits The number of integer bits
* @return id array
*/
IdArray NewIdArray(int64_t length,
DGLContext ctx = DGLContext{kDGLCPU, 0},
IdArray NewIdArray(
int64_t length, DGLContext ctx = DGLContext{kDGLCPU, 0},
uint8_t nbits = 64);
/**
......@@ -55,8 +55,8 @@ IdArray NewIdArray(int64_t length,
* @return the id array
*/
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits = 64,
IdArray VecToIdArray(
const std::vector<T>& vec, uint8_t nbits = 64,
DGLContext ctx = DGLContext{kDGLCPU, 0});
/**
......@@ -148,7 +148,7 @@ IdArray NonZero(BoolArray bool_arr);
* @brief Return the data under the index. In numpy notation, A[I]
* @tparam ValueType The type of return value.
*/
template<typename ValueType>
template <typename ValueType>
ValueType IndexSelect(NDArray array, int64_t index);
/**
......@@ -187,10 +187,11 @@ NDArray Scatter(NDArray array, IdArray indices);
void Scatter_(IdArray index, NDArray value, NDArray out);
/**
* @brief Repeat each element a number of times. Equivalent to np.repeat(array, repeats)
* @brief Repeat each element a number of times. Equivalent to np.repeat(array,
* repeats)
* @param array A 1D vector
* @param repeats A 1D integer vector for number of times to repeat for each element in
* \c array. Must have the same shape as \c array.
* @param repeats A 1D integer vector for number of times to repeat for each
* element in \c array. Must have the same shape as \c array.
*/
NDArray Repeat(NDArray array, IdArray repeats);
......@@ -253,12 +254,13 @@ inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
*
* The packed tensor would be [1, 2, 3, 4, 5].
*
* The length tensor would be [2, 3], i.e. the length of each sequence before padding.
* The length tensor would be [2, 3], i.e. the length of each sequence before
* padding.
*
* The offset tensor would be [0, 2], i.e. the offset to the packed tensor for each
* sequence (before padding)
* The offset tensor would be [0, 2], i.e. the offset to the packed tensor for
* each sequence (before padding)
*/
template<typename ValueType>
template <typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value);
/**
......@@ -295,8 +297,9 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
* @brief Return the cumulative summation (or inclusive sum) of the input array.
*
* The first element out[0] is equal to the first element of the input array
* array[0]. The rest elements are defined recursively, out[i] = out[i-1] + array[i].
* Hence, the result array length is the same as the input array length.
* array[0]. The remaining elements are defined recursively, out[i] = out[i-1]
* + array[i]. Hence, the result array length is the same as the input array
* length.
*
* If prepend_zero is true, then the first element is zero and the result array
* length is the input array length plus one. This is useful for creating
......@@ -320,17 +323,18 @@ IdArray NonZero(NDArray array);
/**
* @brief Sort the ID vector in ascending order.
*
* It performs both sort and arg_sort (returning the sorted index). The sorted index
* is always in int64.
* It performs both sort and arg_sort (returning the sorted index). The sorted
* index is always in int64.
*
* @param array Input array.
* @param num_bits The number of bits used in key comparison. For example, if the data type
* of the input array is int32_t and `num_bits = 8`, it only uses bits in index
* range [0, 8) for sorting. Setting it to a small value could
* speed up the sorting if the underlying sorting algorithm is radix sort (e.g., on GPU).
* Setting it to zero (default value) means using all the bits for comparison.
* On CPU, it currently has no effect.
* @return A pair of arrays: sorted values and sorted index to the original position.
* @param num_bits The number of bits used in key comparison. For example, if
* the data type of the input array is int32_t and `num_bits = 8`, it only uses
* bits in index range [0, 8) for sorting. Setting it to a small value could
* speed up the sorting if the underlying sorting algorithm is
* radix sort (e.g., on GPU). Setting it to zero (default value) means using all
* the bits for comparison. On CPU, it currently has no effect.
* @return A pair of arrays: sorted values and sorted index to the original
* position.
*/
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0);
......@@ -341,9 +345,7 @@ std::string ToDebugString(NDArray array);
// inline implementations
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits,
DGLContext ctx) {
IdArray VecToIdArray(const std::vector<T>& vec, uint8_t nbits, DGLContext ctx) {
IdArray ret = NewIdArray(vec.size(), DGLContext{kDGLCPU, 0}, nbits);
if (nbits == 32) {
std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
......@@ -367,7 +369,8 @@ inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) {
first = false;
result = array->ctx;
} else {
CHECK_EQ(array->ctx, result) << "Context of the input arrays are different";
CHECK_EQ(array->ctx, result)
<< "Context of the input arrays are different";
}
}
return result;
......
......@@ -9,14 +9,16 @@
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include <vector>
#include <utility>
#include <tuple>
#include <string>
#include "./types.h"
#include <tuple>
#include <utility>
#include <vector>
#include "./array_ops.h"
#include "./spmat.h"
#include "./macro.h"
#include "./spmat.h"
#include "./types.h"
namespace dgl {
namespace aten {
......@@ -52,9 +54,9 @@ struct COOMatrix {
/** @brief default constructor */
COOMatrix() = default;
/** @brief constructor */
COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false,
bool csorted = false)
COOMatrix(
int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false, bool csorted = false)
: num_rows(nrows),
num_cols(ncols),
row(rarr),
......@@ -79,8 +81,9 @@ struct COOMatrix {
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCOO), num_rows,
num_cols, {row, col, data}, {row_sorted, col_sorted});
return SparseMatrix(
static_cast<int32_t>(SparseFormat::kCOO), num_rows, num_cols,
{row, col, data}, {row_sorted, col_sorted});
}
bool Load(dmlc::Stream* fs) {
......@@ -122,12 +125,12 @@ struct COOMatrix {
}
/** @brief Return a copy of this matrix on the given device context. */
inline COOMatrix CopyTo(const DGLContext &ctx) const {
if (ctx == row->ctx)
return *this;
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx),
row_sorted, col_sorted);
inline COOMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == row->ctx) return *this;
return COOMatrix(
num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx), row_sorted,
col_sorted);
}
/**
......@@ -139,8 +142,7 @@ struct COOMatrix {
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
if (is_pinned)
return;
if (is_pinned) return;
row.PinMemory_();
col.PinMemory_();
if (!aten::IsNullArray(data)) {
......@@ -157,8 +159,7 @@ struct COOMatrix {
* The context check is deferred to unpinning the NDArray.
*/
inline void UnpinMemory_() {
if (!is_pinned)
return;
if (!is_pinned) return;
row.UnpinMemory_();
col.UnpinMemory_();
if (!aten::IsNullArray(data)) {
......@@ -183,25 +184,25 @@ struct COOMatrix {
///////////////////////// COO routines //////////////////////////
/** @brief Return true if the value (row, col) is non-zero */
bool COOIsNonZero(COOMatrix , int64_t row, int64_t col);
bool COOIsNonZero(COOMatrix, int64_t row, int64_t col);
/**
* @brief Batched implementation of COOIsNonZero.
* @note This operator allows broadcasting (i.e, either row or col can be of length 1).
* @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
*/
runtime::NDArray COOIsNonZero(COOMatrix , runtime::NDArray row, runtime::NDArray col);
runtime::NDArray COOIsNonZero(
COOMatrix, runtime::NDArray row, runtime::NDArray col);
/** @brief Return the nnz of the given row */
int64_t COOGetRowNNZ(COOMatrix , int64_t row);
runtime::NDArray COOGetRowNNZ(COOMatrix , runtime::NDArray row);
int64_t COOGetRowNNZ(COOMatrix, int64_t row);
runtime::NDArray COOGetRowNNZ(COOMatrix, runtime::NDArray row);
/** @brief Return the data array of the given row */
std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix , int64_t row);
std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
COOMatrix, int64_t row);
/** @brief Whether the COO matrix contains data */
inline bool COOHasData(COOMatrix csr) {
return !IsNullArray(csr.data);
}
inline bool COOHasData(COOMatrix csr) { return !IsNullArray(csr.data); }
/**
* @brief Check whether the COO is sorted.
......@@ -217,11 +218,12 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo);
/**
* @brief Get the data and the row,col indices for each returned entries.
*
* The operator supports matrix with duplicate entries and all the matched entries
* will be returned. The operator assumes there is NO duplicate (row, col) pair
* in the given input. Otherwise, the returned result is undefined.
* The operator supports matrix with duplicate entries and all the matched
* entries will be returned. The operator assumes there is NO duplicate (row,
* col) pair in the given input. Otherwise, the returned result is undefined.
*
* @note This operator allows broadcasting (i.e, either row or col can be of length 1).
* @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
* @param mat Sparse matrix
* @param rows Row index
* @param cols Column index
......@@ -230,10 +232,13 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo);
std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/** @brief Get data. The return type is an ndarray due to possible duplicate entries. */
/** @brief Get data. The return type is an ndarray due to possible duplicate
* entries. */
inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
IdArray rows = VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);
IdArray cols = VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);
IdArray rows =
VecToIdArray<int64_t>({row}, mat.row->dtype.bits, mat.row->ctx);
IdArray cols =
VecToIdArray<int64_t>({col}, mat.row->dtype.bits, mat.row->ctx);
const auto& rst = COOGetDataAndIndices(mat, rows, cols);
return rst[2];
}
......@@ -241,18 +246,20 @@ inline runtime::NDArray COOGetAllData(COOMatrix mat, int64_t row, int64_t col) {
/**
* @brief Get the data for each (row, col) pair.
*
* The operator supports matrix with duplicate entries but only one matched entry
* will be returned for each (row, col) pair. Support duplicate input (row, col)
* pairs.
* The operator supports matrix with duplicate entries but only one matched
* entry will be returned for each (row, col) pair. Support duplicate input
* (row, col) pairs.
*
* @note This operator allows broadcasting (i.e, either row or col can be of length 1).
* @note This operator allows broadcasting (i.e, either row or col can be of
* length 1).
*
* @param mat Sparse matrix.
* @param rows Row index.
* @param cols Column index.
* @return Data array. The i^th element is the data of (rows[i], cols[i])
*/
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
runtime::NDArray COOGetData(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
/** @brief Return a transposed COO matrix */
COOMatrix COOTranspose(COOMatrix coo);
......@@ -301,14 +308,15 @@ COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
* @param cols The col index to select
* @return submatrix
*/
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
COOMatrix COOSliceMatrix(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
/** @return True if the matrix has duplicate entries */
bool COOHasDuplicate(COOMatrix coo);
/**
* @brief Deduplicate the entries of a sorted COO matrix, replacing the data with the
* number of occurrences of the row-col coordinates.
* @brief Deduplicate the entries of a sorted COO matrix, replacing the data
* with the number of occurrences of the row-col coordinates.
*/
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
......@@ -316,11 +324,13 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
* @brief Sort the indices of a COO matrix in-place.
*
* The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix
* stores the shuffled index which could be used to fetch edge data.
* col indices are sorted in ascending order too. The data array of the returned
* COOMatrix stores the shuffled index which could be used to fetch edge data.
*
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space.
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of
* nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N)
* space.
*
* @param mat The coo matrix to sort.
* @param sort_column True if column index should be sorted too.
......@@ -331,31 +341,32 @@ void COOSort_(COOMatrix* mat, bool sort_column = false);
* @brief Sort the indices of a COO matrix.
*
* The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix
* stores the shuffled index which could be used to fetch edge data.
* col indices are sorted in ascending order too. The data array of the returned
* COOMatrix stores the shuffled index which could be used to fetch edge data.
*
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N) space.
* Complexity: O(N*log(N)) time and O(1) space, where N is the number of
* nonzeros.
* TODO(minjie): The time complexity could be improved to O(N) by using a O(N)
* space.
*
* @param mat The input coo matrix
* @param sort_column True if column index should be sorted too.
* @return COO matrix with index sorted.
*/
inline COOMatrix COOSort(COOMatrix mat, bool sort_column = false) {
if ((mat.row_sorted && !sort_column) || mat.col_sorted)
return mat;
COOMatrix ret(mat.num_rows, mat.num_cols,
mat.row.Clone(), mat.col.Clone(),
COOHasData(mat)? mat.data.Clone() : mat.data,
mat.row_sorted, mat.col_sorted);
if ((mat.row_sorted && !sort_column) || mat.col_sorted) return mat;
COOMatrix ret(
mat.num_rows, mat.num_cols, mat.row.Clone(), mat.col.Clone(),
COOHasData(mat) ? mat.data.Clone() : mat.data, mat.row_sorted,
mat.col_sorted);
COOSort_(&ret, sort_column);
return ret;
}
/**
* @brief Remove entries from COO matrix by entry indices (data indices)
* @return A new COO matrix as well as a mapping from the new COO entries to the old COO
* entries.
* @return A new COO matrix as well as a mapping from the new COO entries to the
* old COO entries.
*/
COOMatrix COORemove(COOMatrix coo, IdArray entries);
......@@ -365,10 +376,12 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries);
* @param new_row_ids the new row Ids (the index is the old row Id)
* @param new_col_ids the new column Ids (the index is the old col Id).
*/
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
COOMatrix COOReorder(
COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
/**
* @brief Randomly select a fixed number of non-zero entries along each given row independently.
* @brief Randomly select a fixed number of non-zero entries along each given
* row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
......@@ -400,15 +413,12 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArr
* Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* @param replace True if sample with replacement
* @return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array.
* @return A COOMatrix storing the picked row and col indices. Its data field
* stores the index of the picked elements in the value array.
*/
COOMatrix COORowWiseSampling(
COOMatrix mat,
IdArray rows,
int64_t num_samples,
NDArray prob_or_mask = NDArray(),
bool replace = true);
COOMatrix mat, IdArray rows, int64_t num_samples,
NDArray prob_or_mask = NDArray(), bool replace = true);
/**
* @brief Randomly select a fixed number of non-zero entries for each edge type
......@@ -433,8 +443,8 @@ COOMatrix COORowWiseSampling(
* COOMatrix coo = ...;
* IdArray rows = ... ; // [0, 3]
* std::vector<int64_t> num_samples = {2, 2, 2};
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset, num_samples,
* FloatArray(), false);
* COOMatrix sampled = COORowWisePerEtypeSampling(coo, rows, eid2etype_offset,
* num_samples, FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
......@@ -450,20 +460,18 @@ COOMatrix COORowWiseSampling(
* Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* @param replace True if sample with replacement
* @return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array.
* @return A COOMatrix storing the picked row and col indices. Its data field
* stores the index of the picked elements in the value array.
* @note The edges of the entire graph must be ordered by their edge types.
*/
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat,
IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask,
bool replace = true);
const std::vector<NDArray>& prob_or_mask, bool replace = true);
/**
* @brief Select K non-zero entries with the largest weights along each given row.
* @brief Select K non-zero entries with the largest weights along each given
* row.
*
* The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix.
......@@ -492,18 +500,15 @@ COOMatrix COORowWisePerEtypeSampling(
* @param mat Input COO matrix.
* @param rows Rows to sample from.
* @param k The K value.
* @param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform.
* @param ascending If true, elements are sorted by ascending order, equivalent to find
* the K smallest values. Otherwise, find K largest values.
* @return A COOMatrix storing the picked row and col indices. Its data field stores the
* the index of the picked elements in the value array.
* @param weight Weight associated with each entry. Should be of the same length
* as the data array. If an empty array is provided, assume uniform.
* @param ascending If true, elements are sorted by ascending order, equivalent
* to find the K smallest values. Otherwise, find K largest values.
* @return A COOMatrix storing the picked row and col indices. Its data field
* stores the index of the picked elements in the value array.
*/
COOMatrix COORowWiseTopk(
COOMatrix mat,
IdArray rows,
int64_t k,
NDArray weight,
COOMatrix mat, IdArray rows, int64_t k, NDArray weight,
bool ascending = false);
/**
......@@ -535,8 +540,7 @@ COOMatrix COORowWiseTopk(
* COOMatrix_C.num_rows : 3
* COOMatrix_C.num_cols : 4
*/
COOMatrix UnionCoo(
const std::vector<COOMatrix>& coos);
COOMatrix UnionCoo(const std::vector<COOMatrix>& coos);
/**
* @brief DisjointUnion a list COOMatrix into one COOMatrix.
......@@ -566,12 +570,13 @@ COOMatrix UnionCoo(
* COOMatrix_C.num_cols : 5
*
* @param coos The input list of coo matrix.
* @param src_offset A list of integers recording src vertix id offset of each Matrix in coos
* @param src_offset A list of integers recording dst vertix id offset of each Matrix in coos
* @param src_offset A list of integers recording src vertex id offset of each
* Matrix in coos
* @param dst_offset A list of integers recording dst vertex id offset of each
* Matrix in coos
* @return The combined COOMatrix.
*/
COOMatrix DisjointUnionCoo(
const std::vector<COOMatrix>& coos);
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);
/**
* @brief COOMatrix toSimple.
......@@ -591,8 +596,8 @@ COOMatrix DisjointUnionCoo(
* edge_map = [0, 0, 0, 1, 1, 2, 3, 4, 4, 4, 4]
*
* @return The simplified COOMatrix
* The count recording the number of duplicated edges from the original graph.
* The edge mapping from the edge IDs of original graph to those of the
* The count recording the number of duplicated edges from the original
* graph. The edge mapping from the edge IDs of original graph to those of the
* returned graph.
*/
std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
......@@ -642,11 +647,10 @@ std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo);
* @return A list of COOMatrixes representing each disjoint components.
*/
std::vector<COOMatrix> DisjointPartitionCooBySizes(
const COOMatrix &coo,
const uint64_t batch_size,
const std::vector<uint64_t> &edge_cumsum,
const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &dst_vertex_cumsum);
const COOMatrix& coo, const uint64_t batch_size,
const std::vector<uint64_t>& edge_cumsum,
const std::vector<uint64_t>& src_vertex_cumsum,
const std::vector<uint64_t>& dst_vertex_cumsum);
/**
* @brief Slice a contiguous chunk from a COOMatrix
......@@ -684,10 +688,9 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
* @return COOMatrix representing the chunk.
*/
COOMatrix COOSliceContiguousChunk(
const COOMatrix &coo,
const std::vector<uint64_t> &edge_range,
const std::vector<uint64_t> &src_vertex_range,
const std::vector<uint64_t> &dst_vertex_range);
const COOMatrix& coo, const std::vector<uint64_t>& edge_range,
const std::vector<uint64_t>& src_vertex_range,
const std::vector<uint64_t>& dst_vertex_range);
/**
* @brief Create a LineGraph of input coo
......@@ -716,10 +719,11 @@ COOMatrix COOSliceContiguousChunk(
* [0, 1, 1, 0, 0]]
*
* @param coo COOMatrix to create the LineGraph
* @param backtracking whether the pair of (v, u) (u, v) edges are treated as linked
* @param backtracking whether the pair of (v, u) (u, v) edges are treated as
* linked
* @return LineGraph in COO format
*/
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
} // namespace aten
} // namespace dgl
......
......@@ -223,8 +223,10 @@ bool CSRIsSorted(CSRMatrix csr);
std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
/* @brief Get data. The return type is an ndarray due to possible duplicate
* entries. */
/**
* @brief Get data. The return type is an ndarray due to possible duplicate
* entries.
*/
inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
const auto& nbits = mat.indptr->dtype.bits;
const auto& ctx = mat.indptr->ctx;
......
This diff is collapsed.
......@@ -10,7 +10,8 @@
#include <memory>
namespace dgl {
// Util class to call the private/public empty constructor, which is needed for serialization
// Util class to call the private/public empty constructor, which is needed for
// serialization
class Serializer {
public:
template <typename T>
......
This diff is collapsed.
......@@ -7,19 +7,18 @@
#define DGL_RUNTIME_PARALLEL_FOR_H_
#include <dmlc/omp.h>
#include <algorithm>
#include <string>
#include <atomic>
#include <cstdlib>
#include <exception>
#include <vector>
#include <atomic>
#include <string>
#include <utility>
#include <vector>
namespace {
int64_t divup(int64_t x, int64_t y) {
return (x + y - 1) / y;
}
}
int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; }
} // namespace
namespace dgl {
namespace runtime {
......@@ -37,9 +36,7 @@ struct DefaultGrainSizeT {
}
}
size_t operator()() {
return grain_size;
}
size_t operator()() { return grain_size; }
};
} // namespace
......@@ -48,7 +45,9 @@ inline size_t compute_num_threads(size_t begin, size_t end, size_t grain_size) {
if (omp_in_parallel() || end - begin <= grain_size || end - begin == 1)
return 1;
return std::min(static_cast<int64_t>(omp_get_max_threads()), divup(end - begin, grain_size));
return std::min(
static_cast<int64_t>(omp_get_max_threads()),
divup(end - begin, grain_size));
#else
return 1;
#endif
......@@ -60,15 +59,13 @@ static DefaultGrainSizeT default_grain_size;
* @brief OpenMP-based parallel for loop.
*
* It requires each thread's workload to have at least \a grain_size elements.
* The loop body will be a function that takes in two arguments \a begin and \a end, which
* stands for the starting (inclusive) and ending index (exclusive) of the workload.
* The loop body will be a function that takes in two arguments \a begin and \a
* end, which stand for the starting (inclusive) and ending (exclusive) indices
* of the workload.
*/
template <typename F>
void parallel_for(
const size_t begin,
const size_t end,
const size_t grain_size,
F&& f) {
const size_t begin, const size_t end, const size_t grain_size, F&& f) {
if (begin >= end) {
return;
}
......@@ -89,13 +86,11 @@ void parallel_for(
try {
f(begin_tid, end_tid);
} catch (...) {
if (!err_flag.test_and_set())
eptr = std::current_exception();
if (!err_flag.test_and_set()) eptr = std::current_exception();
}
}
}
if (eptr)
std::rethrow_exception(eptr);
if (eptr) std::rethrow_exception(eptr);
#else
f(begin, end);
#endif
......@@ -110,22 +105,21 @@ void parallel_for(
* parallel for pragma with static scheduling.
*/
template <typename F>
void parallel_for(
const size_t begin,
const size_t end,
F&& f) {
void parallel_for(const size_t begin, const size_t end, F&& f) {
parallel_for(begin, end, default_grain_size(), std::forward<F>(f));
}
/**
* @brief OpenMP-based two-stage parallel reduction.
*
* The first-stage reduction function \a f works in parallel. Each thread's workload has
* at least \a grain_size elements. The loop body will be a function that takes in
* the starting index (inclusive), the ending index (exclusive), and the reduction identity.
* The first-stage reduction function \a f works in parallel. Each thread's
* workload has at least \a grain_size elements. The loop body will be a
* function that takes in the starting index (inclusive), the ending index
* (exclusive), and the reduction identity.
*
* The second-stage reduction function \a sf is a binary function working in the main
* thread. It aggregates the partially reduced result computed from each thread.
* The second-stage reduction function \a sf is a binary function working in the
* main thread. It aggregates the partially reduced result computed from each
* thread.
*
* Example to compute a parallelized max reduction of an array \c a:
*
......@@ -134,11 +128,9 @@ void parallel_for(
* 100, // ending index
* 1, // grain size
* -std::numeric_limits<float>::infinity(), // identity
* [&a] (int begin, int end, float ident) { // first-stage partial reducer
* float result = ident;
* for (int i = begin; i < end; ++i)
* result = std::max(result, a[i]);
* return result;
* [&a] (int begin, int end, float ident) { // first-stage partial reducer
* float result = ident;
* for (int i = begin; i < end; ++i)
* result = std::max(result, a[i]);
* return result;
* },
* [] (float result, float partial_result) {
* return std::max(result, partial_result);
......@@ -146,12 +138,8 @@ void parallel_for(
*/
template <typename DType, typename F, typename SF>
DType parallel_reduce(
const size_t begin,
const size_t end,
const size_t grain_size,
const DType ident,
const F& f,
const SF& sf) {
const size_t begin, const size_t end, const size_t grain_size,
const DType ident, const F& f, const SF& sf) {
if (begin >= end) {
return ident;
}
......@@ -174,17 +162,14 @@ DType parallel_reduce(
try {
results[tid] = f(begin_tid, end_tid, ident);
} catch (...) {
if (!err_flag.test_and_set())
eptr = std::current_exception();
if (!err_flag.test_and_set()) eptr = std::current_exception();
}
}
}
if (eptr)
std::rethrow_exception(eptr);
if (eptr) std::rethrow_exception(eptr);
DType out = ident;
for (int64_t i = 0; i < num_threads; ++i)
out = sf(out, results[i]);
for (int64_t i = 0; i < num_threads; ++i) out = sf(out, results[i]);
return out;
}
......
This diff is collapsed.
......@@ -8,9 +8,10 @@
#include <dgl/array.h>
#include <dgl/graph_traversal.h>
#include <vector>
#include <tuple>
#include <utility>
#include <vector>
namespace dgl {
namespace aten {
......@@ -82,7 +83,8 @@ template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
runtime::NDArray CSRIsNonZero(
CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr);
......@@ -104,19 +106,21 @@ bool CSRIsSorted(CSRMatrix csr);
template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids,
runtime::NDArray weights, DType filler);
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
bool return_eids, runtime::NDArray weights, DType filler);
template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler) {
return CSRGetData<XPU, IdType, DType>(csr, rows, cols, false, weights, filler);
return CSRGetData<XPU, IdType, DType>(
csr, rows, cols, false, weights, filler);
}
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, true, NullArray(rows->dtype), -1);
return CSRGetData<XPU, IdType, IdType>(
csr, rows, cols, true, NullArray(rows->dtype), -1);
}
template <DGLDeviceType XPU, typename IdType>
......@@ -141,20 +145,23 @@ template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
CSRMatrix CSRSliceMatrix(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr);
template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, IdArray tag_array, int64_t num_tags);
const CSRMatrix& csr, IdArray tag_array, int64_t num_tags);
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
CSRMatrix CSRReorder(
CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
COOMatrix COOReorder(
COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
......@@ -162,15 +169,16 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
// DType is the type of the probability (or mask) data in prob_or_mask.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace);
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace);
// DType is the type of the probability (or mask) data in prob_or_mask.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace, bool rowwise_etype_sorted);
const std::vector<NDArray>& prob_or_mask, bool replace,
bool rowwise_etype_sorted);
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
......@@ -178,9 +186,9 @@ COOMatrix CSRRowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace, bool rowwise_etype_sorted);
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace,
bool rowwise_etype_sorted);
// DType is the type of weight data.
template <DGLDeviceType XPU, typename IdType, typename DType>
......@@ -189,21 +197,13 @@ COOMatrix CSRRowWiseTopk(
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat,
IdArray rows,
int64_t num_samples,
NDArray tag_offset,
FloatArray bias,
bool replace);
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
FloatArray bias, bool replace);
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr,
int64_t num_samples,
int num_trials,
bool exclude_self_loops,
bool replace,
double redundancy);
const CSRMatrix& csr, int64_t num_samples, int num_trials,
bool exclude_self_loops, bool replace, double redundancy);
// Union CSRMatrixes
template <DGLDeviceType XPU, typename IdType>
......@@ -218,7 +218,8 @@ template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOIsNonZero(COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
runtime::NDArray COOIsNonZero(
COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo);
......@@ -230,15 +231,16 @@ template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);
template <DGLDeviceType XPU, typename IdType>
std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix coo, int64_t row);
std::pair<runtime::NDArray, runtime::NDArray> COOGetRowDataAndIndices(
COOMatrix coo, int64_t row);
template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
runtime::NDArray COOGetData(
COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo);
......@@ -253,7 +255,8 @@ template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
COOMatrix COOSliceMatrix(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
......@@ -273,13 +276,13 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries);
// DType is the type of the probability (or mask) data in prob_or_mask.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace);
COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
bool replace);
// DType is the type of the probability (or mask) data in prob_or_mask.
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples,
const std::vector<NDArray>& prob_or_mask, bool replace);
......@@ -289,8 +292,7 @@ COOMatrix COORowWiseSamplingUniform(
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_samples, bool replace);
// FloatType is the type of weight data.
......@@ -313,14 +315,12 @@ template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
Frontiers DGLDFSLabeledEdges(
const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_nontree_edge, const bool return_labels);
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
} // namespace impl
} // namespace aten
......
/**
/**
* Copyright (c) 2020 by Contributors
* @file kernel/cpu/gather_mm.cc
* @brief GatherMM C APIs and definitions.
*/
#include "./gather_mm.h"
#include <dgl/array.h>
namespace dgl {
......@@ -11,81 +12,72 @@ namespace aten {
/** @brief Generalized SegmentMM. */
template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray seglen_A,
void SegmentMM(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans) {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
}
template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A,
const NDArray dC,
NDArray dB,
const NDArray seglen) {
void SegmentMMBackwardB(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
}
/** @brief Generalized GatherMM. */
template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
void GatherMM(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}
/** @brief Generalized GatherMM_scatter. */
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c) {
void GatherMMScatter(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}
template void GatherMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMMScatter<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
......
......@@ -7,13 +7,14 @@
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_
#include <dgl/array.h>
#include <dmlc/omp.h>
#include <dgl/runtime/parallel_for.h>
#include <functional>
#include <dmlc/omp.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include <memory>
namespace dgl {
namespace aten {
......@@ -41,10 +42,10 @@ namespace impl {
template <typename IdxType>
using PickFn = std::function<void(
IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx)>;
const IdxType* col, const IdxType* data, IdxType* out_idx)>;
// User-defined function for determining the number of elements to pick from one row.
// User-defined function for determining the number of elements to pick from one
// row.
//
// The column indices of the given row are stored in
// [col + off, col + off + len)
......@@ -63,8 +64,8 @@ using PickFn = std::function<void(
// \param data Pointer of the data indices.
template <typename IdxType>
using NumPicksFn = std::function<IdxType(
IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data)>;
IdxType rowid, IdxType off, IdxType len, const IdxType* col,
const IdxType* data)>;
// User-defined function for picking elements from a range within a row.
//
......@@ -73,7 +74,8 @@ using NumPicksFn = std::function<IdxType(
//
// Similarly, the data indices are stored in
// data[off+et_idx[et_offset+i]]
// Data index pointer could be NULL, which means data[i] == off+et_idx[et_offset+i]
// Data index pointer could be NULL, which means data[i] ==
// off+et_idx[et_offset+i]
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
......@@ -93,15 +95,17 @@ using EtypeRangePickFn = std::function<void(
const IdxType* eid, IdxType* out_idx)>;
// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently.
// OpenMP parallelization on rows because each row performs computation
// independently.
template <typename IdxType>
COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
NumPicksFn<IdxType> num_picks_fn) {
COOMatrix CSRRowWisePick(
CSRMatrix mat, IdArray rows, int64_t num_picks, bool replace,
PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
using namespace aten;
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
const IdxType* data = CSRHasData(mat)? static_cast<IdxType*>(mat.data->data) : nullptr;
const IdxType* data =
CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;
const IdxType* rows_data = static_cast<IdxType*>(rows->data);
const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx;
......@@ -115,30 +119,36 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// have at least num_picks number of nnz when replace is false.
//
// If the check holds, remove -1 elements by remove_if operation, which simply
// moves valid elements to the head of arrays and create a view of the original
// array. The implementation consumes a little extra memory than the actual requirement.
// moves valid elements to the head of arrays and create a view of the
// original array. The implementation consumes a little extra memory than the
// actual requirement.
//
// Otherwise, directly use the row and col arrays to construct the result COO matrix.
// Otherwise, directly use the row and col arrays to construct the result COO
// matrix.
//
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism
// is more
// significant. (minjie)
// Do not use omp_get_max_threads() since that doesn't work for compiling without OpenMP.
// Do not use omp_get_max_threads() since that doesn't work for compiling
// without OpenMP.
const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
std::vector<int64_t> global_prefix(num_threads + 1, 0);
// TODO(BarclayII) Using OMP parallel directly instead of using runtime::parallel_for
// does not handle exceptions well (directly aborts when an exception pops up).
// It runs faster though because there is less scheduling. Need to handle
// exceptions better.
// TODO(BarclayII) Using OMP parallel directly instead of using
// runtime::parallel_for does not handle exceptions well (directly aborts when
// an exception pops up). It runs faster though because there is less
// scheduling. Need to handle exceptions better.
IdArray picked_row, picked_col, picked_idx;
#pragma omp parallel num_threads(num_threads)
{
const int thread_id = omp_get_thread_num();
const int64_t start_i = thread_id * (num_rows/num_threads) +
const int64_t start_i =
thread_id * (num_rows / num_threads) +
std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
const int64_t end_i = (thread_id + 1) * (num_rows/num_threads) +
const int64_t end_i =
(thread_id + 1) * (num_rows / num_threads) +
std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
assert(thread_id + 1 < num_threads || end_i == num_rows);
......@@ -149,7 +159,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
local_prefix[0] = 0;
for (int64_t i = start_i; i < end_i; ++i) {
// build prefix-sum
const int64_t local_i = i-start_i;
const int64_t local_i = i - start_i;
const IdxType rid = rows_data[i];
IdxType len = num_picks_fn(
rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
......@@ -157,8 +167,8 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
}
global_prefix[thread_id + 1] = local_prefix[num_local];
#pragma omp barrier
#pragma omp master
#pragma omp barrier
#pragma omp master
{
for (int t = 0; t < num_threads; ++t) {
global_prefix[t + 1] += global_prefix[t];
......@@ -168,7 +178,7 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
}
#pragma omp barrier
#pragma omp barrier
IdxType* picked_rdata = picked_row.Ptr<IdxType>();
IdxType* picked_cdata = picked_col.Ptr<IdxType>();
IdxType* picked_idata = picked_idx.Ptr<IdxType>();
......@@ -180,14 +190,15 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off;
if (len == 0)
continue;
if (len == 0) continue;
const int64_t local_i = i - start_i;
const int64_t row_offset = thread_offset + local_prefix[local_i];
const int64_t num_picks = thread_offset + local_prefix[local_i + 1] - row_offset;
const int64_t num_picks =
thread_offset + local_prefix[local_i + 1] - row_offset;
pick_fn(rid, off, len, num_picks, indices, data, picked_idata + row_offset);
pick_fn(
rid, off, len, num_picks, indices, data, picked_idata + row_offset);
for (int64_t j = 0; j < num_picks; ++j) {
const IdxType picked = picked_idata[row_offset + j];
picked_rdata[row_offset + j] = rid;
......@@ -200,25 +211,25 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
const int64_t new_len = global_prefix.back();
return COOMatrix(
mat.num_rows,
mat.num_cols,
mat.num_rows, mat.num_cols,
picked_row.CreateView({new_len}, picked_row->dtype),
picked_col.CreateView({new_len}, picked_row->dtype),
picked_idx.CreateView({new_len}, picked_row->dtype));
}
// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently.
// OpenMP parallelization on rows because each row performs computation
// independently.
template <typename IdxType, typename DType>
COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
const std::vector<int64_t>& eid2etype_offset,
COOMatrix CSRRowWisePerEtypePick(
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& num_picks, bool replace,
bool rowwise_etype_sorted, EtypeRangePickFn<IdxType> pick_fn,
const std::vector<NDArray>& prob_or_mask) {
using namespace aten;
const IdxType* indptr = mat.indptr.Ptr<IdxType>();
const IdxType* indices = mat.indices.Ptr<IdxType>();
const IdxType* eid = CSRHasData(mat)? mat.data.Ptr<IdxType>() : nullptr;
const IdxType* eid = CSRHasData(mat) ? mat.data.Ptr<IdxType>() : nullptr;
const IdxType* rows_data = rows.Ptr<IdxType>();
const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx;
......@@ -229,8 +240,8 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
std::vector<IdArray> picked_idxs(rows->shape[0]);
// Check if the number of picks have the same value.
// If so, we can potentially speed up if we have a node with total number of neighbors
// less than the given number of picks with replace=False.
// If so, we can potentially speed up if we have a node with total number of
// neighbors less than the given number of picks with replace=False.
bool same_num_pick = true;
int64_t num_pick_value = num_picks[0];
for (int64_t num_pick : num_picks) {
......@@ -267,9 +278,10 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
for (int64_t j = 0; j < len; ++j) {
const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
eid2etype_offset.begin(), eid2etype_offset.end(),
homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \
const IdxType heterogenized_eid =
homogenized_eid - eid2etype_offset[heterogenized_etype];
if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
......@@ -305,17 +317,21 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
for (int64_t j = 0; j < len; ++j) {
const IdxType homogenized_eid = eid ? eid[off + j] : off + j;
auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
eid2etype_offset.begin(), eid2etype_offset.end(),
homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \
const IdxType heterogenized_eid =
homogenized_eid - eid2etype_offset[heterogenized_etype];
et[j] = heterogenized_etype;
et_eid[j] = heterogenized_eid;
}
if (!rowwise_etype_sorted) // the edge type is sorted, not need to sort it
std::sort(et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) {return et[i1] < et[i2];});
CHECK_LT(et[et_idx[len - 1]], num_etypes) << "etype values exceed the number of fanouts";
if (!rowwise_etype_sorted) // the edge type is sorted, not need to sort
// it
std::sort(
et_idx.begin(), et_idx.end(),
[&et](IdxType i1, IdxType i2) { return et[i1] < et[i2]; });
CHECK_LT(et[et_idx[len - 1]], num_etypes)
<< "etype values exceed the number of fanouts";
IdxType cur_et = et[et_idx[0]];
int64_t et_offset = 0;
......@@ -333,14 +349,18 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
// fast path, select all
for (int64_t k = 0; k < et_len; ++k) {
const IdxType eid_offset = off + et_idx[et_offset + k];
const IdxType homogenized_eid = eid ? eid[eid_offset] : eid_offset;
const IdxType homogenized_eid =
eid ? eid[eid_offset] : eid_offset;
auto it = std::upper_bound(
eid2etype_offset.begin(), eid2etype_offset.end(), homogenized_eid);
const IdxType heterogenized_etype = it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid = \
eid2etype_offset.begin(), eid2etype_offset.end(),
homogenized_eid);
const IdxType heterogenized_etype =
it - eid2etype_offset.begin() - 1;
const IdxType heterogenized_eid =
homogenized_eid - eid2etype_offset[heterogenized_etype];
if (!has_probs || IsNullArray(prob_or_mask[heterogenized_etype])) {
if (!has_probs ||
IsNullArray(prob_or_mask[heterogenized_etype])) {
// No probability, select all
rows.push_back(rid);
cols.push_back(indices[eid_offset]);
......@@ -357,32 +377,31 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
}
}
} else {
IdArray picked_idx = Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
IdArray picked_idx =
Full(-1, num_picks[cur_et], sizeof(IdxType) * 8, ctx);
IdxType* picked_idata = picked_idx.Ptr<IdxType>();
// need call random pick
pick_fn(off, et_offset, cur_et,
et_len, et_idx, et_eid,
eid, picked_idata);
pick_fn(
off, et_offset, cur_et, et_len, et_idx, et_eid, eid,
picked_idata);
for (int64_t k = 0; k < num_picks[cur_et]; ++k) {
const IdxType picked = picked_idata[k];
if (picked == -1)
continue;
if (picked == -1) continue;
rows.push_back(rid);
cols.push_back(indices[off+et_idx[et_offset+picked]]);
cols.push_back(indices[off + et_idx[et_offset + picked]]);
if (eid) {
idx.push_back(eid[off+et_idx[et_offset+picked]]);
idx.push_back(eid[off + et_idx[et_offset + picked]]);
} else {
idx.push_back(off+et_idx[et_offset+picked]);
idx.push_back(off + et_idx[et_offset + picked]);
}
}
}
if (j+1 == len)
break;
if (j + 1 == len) break;
// next etype
cur_et = et[et_idx[j+1]];
et_offset = j+1;
cur_et = et[et_idx[j + 1]];
et_offset = j + 1;
et_len = 1;
} else {
et_len++;
......@@ -402,31 +421,32 @@ COOMatrix CSRRowWisePerEtypePick(CSRMatrix mat, IdArray rows,
IdArray picked_row = Concat(picked_rows);
IdArray picked_col = Concat(picked_cols);
IdArray picked_idx = Concat(picked_idxs);
return COOMatrix(mat.num_rows, mat.num_cols,
picked_row, picked_col, picked_idx);
return COOMatrix(
mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
}
// Template for picking non-zero values row-wise. The implementation first slices
// out the corresponding rows and then converts it to CSR format. It then performs
// row-wise pick on the CSR matrix and rectifies the returned results.
// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts it to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType>
COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn,
NumPicksFn<IdxType> num_picks_fn) {
COOMatrix COORowWisePick(
COOMatrix mat, IdArray rows, int64_t num_picks, bool replace,
PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const IdArray new_rows =
Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePick<IdxType>(
csr, new_rows, num_picks, replace, pick_fn, num_picks_fn);
return COOMatrix(mat.num_rows, mat.num_cols,
return COOMatrix(
mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col,
picked.data);
picked.col, picked.data);
}
// Template for picking non-zero values row-wise. The implementation first slices
// out the corresponding rows and then converts it to CSR format. It then performs
// row-wise pick on the CSR matrix and rectifies the returned results.
// Template for picking non-zero values row-wise. The implementation first
// slices out the corresponding rows and then converts it to CSR format. It then
// performs row-wise pick on the CSR matrix and rectifies the returned results.
template <typename IdxType, typename DType>
COOMatrix COORowWisePerEtypePick(
COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
......@@ -435,13 +455,15 @@ COOMatrix COORowWisePerEtypePick(
const std::vector<NDArray>& prob_or_mask) {
using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const IdArray new_rows =
Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePerEtypePick<IdxType, DType>(
csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn, prob_or_mask);
return COOMatrix(mat.num_rows, mat.num_cols,
csr, new_rows, eid2etype_offset, num_picks, replace, false, pick_fn,
prob_or_mask);
return COOMatrix(
mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col,
picked.data);
picked.col, picked.data);
}
} // namespace impl
......
This diff is collapsed.
......@@ -3,8 +3,9 @@
* @file array/cpu/rowwise_topk.cc
* @brief rowwise topk
*/
#include <numeric>
#include <algorithm>
#include <numeric>
#include "./rowwise_pick.h"
namespace dgl {
......@@ -14,9 +15,9 @@ namespace {
template <typename IdxType>
inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
NumPicksFn<IdxType> num_picks_fn = [k]
(IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data) {
NumPicksFn<IdxType> num_picks_fn = [k](IdxType rowid, IdxType off,
IdxType len, const IdxType* col,
const IdxType* data) {
const int64_t max_num_picks = (k == -1) ? len : k;
return std::min(static_cast<IdxType>(max_num_picks), len);
};
......@@ -26,28 +27,28 @@ inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
template <typename IdxType, typename DType>
inline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) {
const DType* wdata = static_cast<DType*>(weight->data);
PickFn<IdxType> pick_fn = [ascending, wdata]
(IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
PickFn<IdxType> pick_fn = [ascending, wdata](
IdxType rowid, IdxType off, IdxType len,
IdxType num_picks, const IdxType* col,
const IdxType* data, IdxType* out_idx) {
std::function<bool(IdxType, IdxType)> compare_fn;
if (ascending) {
if (data) {
compare_fn = [wdata, data] (IdxType i, IdxType j) {
compare_fn = [wdata, data](IdxType i, IdxType j) {
return wdata[data[i]] < wdata[data[j]];
};
} else {
compare_fn = [wdata] (IdxType i, IdxType j) {
compare_fn = [wdata](IdxType i, IdxType j) {
return wdata[i] < wdata[j];
};
}
} else {
if (data) {
compare_fn = [wdata, data] (IdxType i, IdxType j) {
compare_fn = [wdata, data](IdxType i, IdxType j) {
return wdata[data[i]] > wdata[data[j]];
};
} else {
compare_fn = [wdata] (IdxType i, IdxType j) {
compare_fn = [wdata](IdxType i, IdxType j) {
return wdata[i] > wdata[j];
};
}
......
......@@ -4,6 +4,7 @@
* @brief SDDMM C APIs and definitions.
*/
#include "./sddmm.h"
#include <dgl/array.h>
namespace dgl {
......@@ -25,7 +26,7 @@ namespace aten {
} \
} while (0)
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
do { \
if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \
......@@ -41,35 +42,26 @@ namespace aten {
} \
} while (0)
/** @brief Generalized SDDMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
void SDDMMCsr(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
});
});
}
/** @brief Generalized SDDMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
void SDDMMCsrHetero(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
......@@ -79,7 +71,8 @@ void SDDMMCsrHetero(const std::string& op,
NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray out = vec_out[etype];
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
}
});
});
......@@ -87,78 +80,62 @@ void SDDMMCsrHetero(const std::string& op,
// Explicit instantiations of SDDMMCsr for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): each parameter tail is listed twice (old and new line
// wrapping of the same declaration).
template void SDDMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
// Explicit instantiations of SDDMMCsrHetero for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): the last two parameters are named in_eid/out_eid here but
// lhs_nid/rhs_nid in the definition; names in an explicit instantiation do
// not affect matching, but the mismatch is confusing - consider aligning.
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
/**
 * @brief Generalized SDDMM on Coo format.
 *
 * Thin dispatcher: forwards to cpu::SDDMMCoo with the operator and the two
 * operand targets turned into template arguments. The XPU template parameter
 * is not referenced in the body.
 *
 * @param op Operator name handed to SWITCH_OP (presumably selects the type
 *     bound to `Op`; the macro body is not visible here - confirm).
 * @param bcast Broadcast information, passed through unchanged.
 * @param coo Sparse matrix in COO format, passed through unchanged.
 * @param lhs Left operand array.
 * @param rhs Right operand array.
 * @param out Output array written by cpu::SDDMMCoo.
 * @param lhs_target Runtime selector that SWITCH_TARGET converts into the
 *     compile-time constant `LhsTarget`.
 * @param rhs_target Runtime selector that SWITCH_TARGET converts into the
 *     compile-time constant `RhsTarget`.
 */
template <int XPU, typename IdType, typename DType>
// NOTE(review): the header and the cpu:: call below each appear twice; this
// looks like the pre- and post-clang-format layouts of the same code.
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
void SDDMMCoo(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
});
});
}
/** @brief Generalized SDDMM on Coo format with Heterograph support. */
template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
void SDDMMCooHetero(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
......@@ -168,7 +145,8 @@ void SDDMMCooHetero(const std::string& op,
NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray out = vec_out[etype];
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
}
});
});
......@@ -176,48 +154,40 @@ void SDDMMCooHetero(const std::string& op,
// Explicit instantiations of SDDMMCoo for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): each parameter tail is listed twice (old and new line
// wrapping of the same declaration).
template void SDDMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
// Explicit instantiations of SDDMMCooHetero for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): the last two parameters are named in_eid/out_eid here but
// lhs_nid/rhs_nid in the definition; names in an explicit instantiation do
// not affect matching, but the mismatch is confusing - consider aligning.
template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
} // namespace aten
......
......@@ -4,8 +4,11 @@
* @brief Segment reduce C APIs and definitions.
*/
#include "./segment_reduce.h"
#include <dgl/array.h>
#include <string>
#include "./spmm_binary_ops.h"
namespace dgl {
......@@ -14,10 +17,7 @@ namespace aten {
/** @brief Segment Reduce operator. */
template <int XPU, typename IdType, typename DType>
void SegmentReduce(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg) {
if (op == "sum") {
cpu::SegmentSum<IdType, DType>(feat, offsets, out);
......@@ -36,73 +36,47 @@ void SegmentReduce(
/**
 * @brief Scatter Add.
 *
 * Forwards all three arrays to cpu::ScatterAdd<IdType, DType>. The XPU
 * template parameter is not referenced in the body.
 *
 * @param feat Input feature array.
 * @param idx Index array (presumably the destination position for each
 *     entry of feat - confirm against cpu::ScatterAdd).
 * @param out Output array updated by cpu::ScatterAdd.
 */
template <int XPU, typename IdType, typename DType>
// NOTE(review): the header appears twice below (old and new line wrapping).
void ScatterAdd(NDArray feat,
NDArray idx,
NDArray out) {
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
cpu::ScatterAdd<IdType, DType>(feat, idx, out);
}
/**
 * @brief Update gradients for reduce operator max/min on heterogeneous
 * graph.
 *
 * Forwards every argument, in order, to
 * cpu::UpdateGradMinMax_hetero<IdType, DType>. The XPU template parameter is
 * not referenced in the body.
 *
 * @param g Heterogeneous graph handle.
 * @param op Reduce operator name (max/min per the brief above).
 * @param feat Feature arrays.
 * @param idx Index arrays.
 * @param idx_etype Edge-type index arrays (meaning inferred from the name -
 *     confirm against the cpu:: implementation).
 * @param out Pointer to the output arrays; passed through to the cpu::
 *     implementation.
 */
template <int XPU, typename IdType, typename DType>
// NOTE(review): the header appears twice below (old and new line wrapping).
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
void UpdateGradMinMax_hetero(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
cpu::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
}
/**
 * @brief Backward function of segment cmp.
 *
 * Forwards all three arrays to cpu::BackwardSegmentCmp<IdType, DType>. The
 * XPU template parameter is not referenced in the body.
 *
 * @param feat Input array (presumably the incoming gradient - confirm).
 * @param arg Array passed as the second argument (presumably the arg-max/min
 *     positions recorded by the forward pass - confirm).
 * @param out Output array updated by cpu::BackwardSegmentCmp.
 */
template <int XPU, typename IdType, typename DType>
// NOTE(review): the header appears twice below (old and new line wrapping).
void BackwardSegmentCmp(
NDArray feat,
NDArray arg,
NDArray out) {
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}
// Explicit instantiations of SegmentReduce for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): each parameter list is shown twice (old one-per-line layout,
// then the reformatted layout).
template void SegmentReduce<kDGLCPU, int32_t, float>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, float>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
// Explicit instantiations of ScatterAdd for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): the <int64_t, double> instantiation names its second
// parameter `arg` while the other three use `idx`; the name
// has no effect on the instantiation but looks like a copy-paste slip.
template void ScatterAdd<kDGLCPU, int32_t, float>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, float>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int32_t, double>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op,
......@@ -122,21 +96,13 @@ template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
// Explicit instantiations of BackwardSegmentCmp for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): each parameter list is shown twice (old one-per-line layout,
// then the reformatted single-line layout).
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
} // namespace aten
} // namespace dgl
......@@ -4,6 +4,7 @@
* @brief SPMM C APIs and definitions.
*/
#include "./spmm.h"
#include <dgl/array.h>
namespace dgl {
......@@ -11,12 +12,9 @@ namespace aten {
/** @brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray efeat,
NDArray out,
void SpMMCsr(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
const int64_t dim = bcast.out_len;
if (reduce == "sum") {
......@@ -25,13 +23,15 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
});
} else if (reduce == "max" || reduce == "min") {
SWITCH_OP(op, Op, {
DType *out_off = out.Ptr<DType>();
DType* out_off = out.Ptr<DType>();
if (reduce == "max") {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
std::fill(
out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
} else {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
std::fill(
out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
}
......@@ -43,12 +43,11 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
/** @brief Generalized SpMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
void SpMMCsrHetero(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat,
std::vector<NDArray>* vec_out,
const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids) {
......@@ -60,8 +59,10 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype];
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray ufeat =
(vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat =
(vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = (*vec_out)[dst_id];
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
}
......@@ -71,21 +72,27 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
std::vector<bool> updated((*vec_out).size(), false);
// TODO(Israt): use vector updated to fill(out...) too
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
DType *out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
DType* out_off = (*vec_out)[out_node_tids[etype]].Ptr<DType>();
if (reduce == "max")
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Max<DType>::zero);
std::fill(
out_off, out_off + vec_csr[etype].num_rows * dim,
cpu::op::Max<DType>::zero);
else
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, cpu::op::Min<DType>::zero);
std::fill(
out_off, out_off + vec_csr[etype].num_rows * dim,
cpu::op::Min<DType>::zero);
const dgl_type_t dst_id = out_node_tids[etype];
if (!updated[dst_id]) {
updated[dst_id] = true;
if (Op::use_lhs) {
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
std::fill(argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
std::fill(
argu_ntype, argu_ntype + vec_csr[etype].num_rows * dim, -1);
}
if (Op::use_rhs) {
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
std::fill(arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
std::fill(
arge_etype, arge_etype + vec_csr[etype].num_rows * dim, -1);
}
}
}
......@@ -94,17 +101,21 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype];
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray ufeat =
(vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat =
(vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = (*vec_out)[dst_id];
if (reduce == "max") {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
(*out_aux)[3][dst_id], src_id, etype);
} else {
cpu::SpMMCmpCsrHetero<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id], (*out_aux)[1][dst_id],
(*out_aux)[2][dst_id], (*out_aux)[3][dst_id], src_id, etype);
bcast, csr, ufeat, efeat, out, (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id],
(*out_aux)[3][dst_id], src_id, etype);
}
}
});
......@@ -114,119 +125,104 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
}
// Explicit instantiations of SpMMCsr for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): each parameter list is shown twice (old and new line
// wrapping of the same declaration).
template void SpMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
// Explicit instantiations of SpMMCsrHetero for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): the leading parameters are shown twice (old and new line
// wrapping); the two trailing *_node_tids parameters appear once.
template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
/**
 * @brief Edge_softmax_csr forward op on Csr format.
 *
 * Thin dispatcher: forwards to cpu::Edge_softmax_csr_forward with the
 * operator bound to `Op` by SWITCH_OP. The XPU template parameter is not
 * referenced in the body.
 *
 * @param op Operator name handed to SWITCH_OP (presumably selects the type
 *     bound to `Op`; the macro body is not visible here - confirm).
 * @param bcast Broadcast information, passed through unchanged.
 * @param csr Sparse matrix in CSR format, passed through unchanged.
 * @param ufeat Array passed as the node-feature operand.
 * @param efeat Array passed as the edge-feature operand.
 * @param out Output array written by the cpu:: implementation.
 */
template <int XPU, typename IdType, typename DType>
// NOTE(review): the header and the cpu:: call below each appear twice; this
// looks like the pre- and post-clang-format layouts of the same code.
void Edge_softmax_csr_forward(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray efeat,
NDArray out) {
void Edge_softmax_csr_forward(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out) {
SWITCH_OP(op, Op, {
cpu::Edge_softmax_csr_forward<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
cpu::Edge_softmax_csr_forward<IdType, DType, Op>(
bcast, csr, ufeat, efeat, out);
});
}
/**
 * @brief Edge_softmax_csr backward op on Csr format.
 *
 * Thin dispatcher: forwards to cpu::Edge_softmax_csr_backward with the
 * operator bound to `Op` by SWITCH_OP. The XPU template parameter is not
 * referenced in the body.
 *
 * @param op Operator name handed to SWITCH_OP (presumably selects the type
 *     bound to `Op`; the macro body is not visible here - confirm).
 * @param bcast Broadcast information, passed through unchanged.
 * @param csr Sparse matrix in CSR format, passed through unchanged.
 * @param out Array passed as the first data operand (presumably the forward
 *     output - confirm).
 * @param sds Array passed as the second data operand (meaning not visible
 *     here - confirm against the cpu:: implementation).
 * @param back_out Array passed last; presumably receives the backward
 *     result - confirm.
 */
template <int XPU, typename IdType, typename DType>
// NOTE(review): the header and the cpu:: call below each appear twice; this
// looks like the pre- and post-clang-format layouts of the same code.
void Edge_softmax_csr_backward(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray out,
NDArray sds,
NDArray back_out) {
void Edge_softmax_csr_backward(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray out, NDArray sds, NDArray back_out) {
SWITCH_OP(op, Op, {
cpu::Edge_softmax_csr_backward<IdType, DType, Op>(bcast, csr, out, sds, back_out);
cpu::Edge_softmax_csr_backward<IdType, DType, Op>(
bcast, csr, out, sds, back_out);
});
}
// Explicit instantiations of Edge_softmax_csr_forward for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): the leading parameters are shown twice (old and new line
// wrapping of the same declaration).
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, float>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, double>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
// Explicit instantiations of Edge_softmax_csr_backward for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): the array parameters are named ufeat/efeat/out here, but the
// definition names them out/sds/back_out. The types match, so the
// instantiation is valid; the names look copied from the forward op and are
// misleading - consider aligning them with the definition.
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, float>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, double>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, double>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
/** @brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat,
NDArray efeat,
NDArray out,
void SpMMCoo(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_OP(op, Op, {
......@@ -247,21 +243,21 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
}
// Explicit instantiations of SpMMCoo for kDGLCPU, covering
// IdType in {int32_t, int64_t} x DType in {float, double}.
// NOTE(review): each parameter list is shown twice (old and new line
// wrapping of the same declaration).
template void SpMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
} // namespace aten
} // namespace dgl
......@@ -7,6 +7,7 @@
#define DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <limits>
namespace dgl {
namespace aten {
......
......@@ -63,8 +63,10 @@ template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
// Explicit instantiations of the array-index IndexSelect overload for
// kDGLCUDA. The __nv_bfloat16 variants are compiled only when BF16_ENABLED
// is set.
// NOTE(review): the bf16 lines appear twice (single-line and wrapped forms
// of the same declaration).
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#if BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int32_t>(
NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __nv_bfloat16, int64_t>(
NDArray, IdArray);
#endif // BF16_ENABLED
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
......@@ -76,8 +78,8 @@ DType IndexSelect(NDArray array, int64_t index) {
auto device = runtime::DeviceAPI::Get(array->ctx);
DType ret = static_cast<DType>(0.0f);
device->CopyDataFromTo(
static_cast<DType*>(array->data) + index, 0, &ret, 0,
sizeof(DType), array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
static_cast<DType*>(array->data) + index, 0, &ret, 0, sizeof(DType),
array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
return ret;
}
......@@ -87,7 +89,8 @@ template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
// Explicit instantiations of the single-index IndexSelect overload (returns
// one DType value) for kDGLCUDA. The __nv_bfloat16 variant is compiled only
// when BF16_ENABLED is set.
// NOTE(review): the bf16 line appears twice (single-line and wrapped forms
// of the same declaration).
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
#if BF16_ENABLED
template __nv_bfloat16 IndexSelect<kDGLCUDA, __nv_bfloat16>(NDArray array, int64_t index);
template __nv_bfloat16 IndexSelect<kDGLCUDA, __nv_bfloat16>(
NDArray array, int64_t index);
#endif // BF16_ENABLED
template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
......
This diff is collapsed.
......@@ -4,6 +4,7 @@
* @brief Array scatter GPU implementation
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
......@@ -13,8 +14,8 @@ namespace aten {
namespace impl {
template <typename DType, typename IdType>
__global__ void _ScatterKernel(const IdType* index, const DType* value,
int64_t length, DType* out) {
__global__ void _ScatterKernel(
const IdType* index, const DType* value, int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
......@@ -33,15 +34,15 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream,
idx, val, len, outd);
CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream, idx, val, len, outd);
}
// Explicit instantiations of Scatter_ for kDGLCUDA with int32_t as the
// second type argument. The __nv_bfloat16 variant is compiled only when
// BF16_ENABLED is set.
// NOTE(review): the bf16 line appears twice (single-line and wrapped forms
// of the same declaration).
template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
#if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __nv_bfloat16, int32_t>(
IdArray, NDArray, NDArray);
#endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
......@@ -49,7 +50,8 @@ template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
// Explicit instantiations of Scatter_ for kDGLCUDA with int64_t as the
// second type argument. The __nv_bfloat16 variant is compiled only when
// BF16_ENABLED is set.
// NOTE(review): the bf16 line appears twice (single-line and wrapped forms
// of the same declaration).
template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
#if BF16_ENABLED
template void Scatter_<kDGLCUDA, __nv_bfloat16, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __nv_bfloat16, int64_t>(
IdArray, NDArray, NDArray);
#endif // BF16_ENABLED
template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment