Unverified commit ab2bd1f1, authored by Israt Nisa and committed by GitHub

[Feature] Add cuda support for Sparse Matrix multiplication, summation and masking (#2782)



* init cuda support

* cuSPARSE err

* passed unittest for csr_mm/SpGEMM. int64 not supported

* Debugging cuSPARSE error 3

* csrgeam only supports int32?

* disabling int64 for cuda

* refactor and add CSRMask

* lint

* oops

* remove todo

* rewrite CSRMask with CSRGetData

* lint

* fix test

* address comments

* lint

* fix

* addresses comments and rename BUG_ON
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-30-71.ec2.internal>
Co-authored-by: Quan Gan <coin2028@hotmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent e18c2ab4
@@ -79,6 +79,16 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);
*/
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
/*!
* \brief Return an array full of the given value with the given type.
* \param val The value to fill.
* \param length Number of elements.
* \param ctx Device context
* \return the result array
*/
template <typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx);
/*! \brief Create a deep copy of the given array */
IdArray Clone(IdArray arr);
......
@@ -198,6 +198,31 @@ inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
*/
runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
/*!
* \brief Get the data for each (row, col) pair, then index into the weights array.
*
* The operator supports matrices with duplicate entries, but only one matched entry
* is returned for each (row, col) pair. Duplicate input (row, col) pairs are also
* supported.
*
* If a (row, col) pair does not correspond to a valid non-zero element to index into
* the weights array, DGL returns the value \a filler for that pair instead.
*
* \note This operator allows broadcasting (i.e., either row or col can be of length 1).
*
* \tparam DType the data type of the weights array.
* \param mat Sparse matrix.
* \param rows Row index.
* \param cols Column index.
* \param weights The weights array.
* \param filler The value to return for row-column pairs not existent in the matrix.
* \return Data array. The i^th element is the data of (rows[i], cols[i])
*/
template <typename DType>
runtime::NDArray CSRGetData(
CSRMatrix, runtime::NDArray rows, runtime::NDArray cols, runtime::NDArray weights,
DType filler);
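To make the filler and first-match semantics above concrete, here is a minimal sketch in plain C++ over hand-written CSR arrays. The matrix, weights and query pairs are made-up illustration data, and weights are indexed by position rather than by edge ID, so this is not the aten implementation itself.

// Toy lookup with the semantics described above: first match wins, missing
// (row, col) pairs get the filler value.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // 3x3 CSR matrix with nonzeros at (0,1), (1,0), (1,2), (2,2).
  std::vector<int64_t> indptr  = {0, 1, 3, 4};
  std::vector<int64_t> indices = {1, 0, 2, 2};
  std::vector<double>  weights = {10.0, 20.0, 30.0, 40.0};  // one weight per nonzero
  std::vector<int64_t> rows = {0, 1, 2};
  std::vector<int64_t> cols = {1, 1, 2};
  const double filler = -1.0;  // returned when (row, col) is not a nonzero

  for (size_t p = 0; p < rows.size(); ++p) {
    double out = filler;
    for (int64_t i = indptr[rows[p]]; i < indptr[rows[p] + 1]; ++i) {
      if (indices[i] == cols[p]) {  // first matched entry wins, as documented above
        out = weights[i];
        break;
      }
    }
    std::cout << "(" << rows[p] << "," << cols[p] << ") -> " << out << "\n";
  }
  return 0;  // prints 10, -1 (missing entry), 40
}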
/*! \brief Return a transposed CSR matrix */
CSRMatrix CSRTranspose(CSRMatrix csr);
......
@@ -55,7 +55,8 @@ void SDDMM(const std::string& op,
/*!
* \brief Sparse-sparse matrix multiplication.
*
* The sparse matrices must have scalar weights (i.e. \a A_weights and \a B_weights
* are 1D vectors.)
*/
std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A,
@@ -64,17 +65,15 @@ std::pair<CSRMatrix, NDArray> CSRMM(
NDArray B_weights);
/*!
* \brief Summing up a list of sparse matrices.
*
* The sparse matrices must have scalar weights (i.e. the arrays in \a A_weights
* are 1D vectors.)
*/
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights);
/*!
* \brief Return a sparse matrix with the values of A but nonzero entry locations of B.
*/
NDArray CSRMask(const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B);
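As a sketch of what the masking operation computes (A[B != 0] in scipy terms), the following standalone C++ snippet walks the rows of two toy CSR matrices and emits one value per nonzero of B, defaulting to zero where A has no entry. The data is invented and the snippet does not use the aten types.

// Toy CSRMask semantics: pick A's weight at every nonzero location of B.
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // A: 2x3 matrix with nonzeros (0,0)=1.5, (0,2)=2.5, (1,1)=3.5.
  std::vector<int64_t> A_indptr  = {0, 2, 3};
  std::vector<int64_t> A_indices = {0, 2, 1};
  std::vector<double>  A_weights = {1.5, 2.5, 3.5};
  // B: 2x3 matrix with nonzeros at (0,2) and (1,0).
  std::vector<int64_t> B_indptr  = {0, 1, 2};
  std::vector<int64_t> B_indices = {2, 0};

  std::vector<double> out(B_indices.size(), 0.0);  // one value per nonzero of B
  for (int64_t r = 0; r + 1 < static_cast<int64_t>(A_indptr.size()); ++r) {
    std::unordered_map<int64_t, double> row;  // column -> weight of A in this row
    for (int64_t i = A_indptr[r]; i < A_indptr[r + 1]; ++i)
      row[A_indices[i]] = A_weights[i];
    for (int64_t j = B_indptr[r]; j < B_indptr[r + 1]; ++j) {
      auto it = row.find(B_indices[j]);
      out[j] = (it == row.end()) ? 0.0 : it->second;  // 0 when A has no entry there
    }
  }
  for (double v : out) std::cout << v << " ";  // prints "2.5 0"
  std::cout << "\n";
  return 0;
}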
} // namespace aten
} // namespace dgl
......
@@ -555,7 +555,7 @@ DGL_DLL void DGLLoadTensorAdapter(const char *path);
*
* Hints the user to file a bug report if the condition fails.
*/
#define BUG_IF_FAIL(cond) \
CHECK(cond) << "A bug has been occurred. " \
"Please file a bug report at https://github.com/dmlc/dgl/issues. " \
"Message: "
......
@@ -696,6 +696,9 @@ class HeteroGraphIndex(ObjectBase):
indptr = utils.toindex(rst(0), self.dtype).tonumpy()
indices = utils.toindex(rst(1), self.dtype).tonumpy()
data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices)
# Check if edge ID is omitted
if return_edge_ids and data.shape[0] == 0:
data = np.arange(nnz)
return scipy.sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols))
elif fmt == 'coo':
idx = utils.toindex(rst(0), self.dtype).tonumpy()
......
@@ -366,124 +366,87 @@ def _bwd_segment_cmp(feat, arg, m):
to_dgl_nd_for_write(out))
return out
def csrmm(A, A_weights, B, B_weights, num_vtypes):
    """Return a graph whose adjacency matrix is the sparse matrix multiplication
    of those of two given graphs.

    Note that the edge weights of both graphs must be scalar, i.e. :attr:`A_weights`
    and :attr:`B_weights` must be 1D vectors.

    Parameters
    ----------
    A : HeteroGraphIndex
        The input graph index as left operand.
    A_weights : Tensor
        The edge weights of graph A as 1D tensor.
    B : HeteroGraphIndex
        The input graph index as right operand.
    B_weights : Tensor
        The edge weights of graph B as 1D tensor.
    num_vtypes : int
        The number of node types for the returned graph (must be either 1 or 2).

    Returns
    -------
    C : HeteroGraphIndex
        The output graph index.
    C_weights : Tensor
        The edge weights of the output graph.
    """
    C, C_weights = _CAPI_DGLCSRMM(
        A, F.to_dgl_nd(A_weights), B, F.to_dgl_nd(B_weights), num_vtypes)
    return C, F.from_dgl_nd(C_weights)

def csrsum(As, A_weights):
    """Return a graph whose adjacency matrix is the sparse matrix summation
    of the given list of graphs.

    Note that the edge weights of all graphs must be scalar, i.e. the arrays in
    :attr:`A_weights` must be 1D vectors.

    Parameters
    ----------
    As : list[HeteroGraphIndex]
        The input graph indices.
    A_weights : list[Tensor]
        The edge weights of each graph as 1D tensors.

    Returns
    -------
    C : HeteroGraphIndex
        The output graph index.
    C_weights : Tensor
        The edge weights of the output graph.
    """
    C, C_weights = _CAPI_DGLCSRSum(As, [F.to_dgl_nd(w) for w in A_weights])
    return C, F.from_dgl_nd(C_weights)

def csrmask(A, A_weights, B):
    """Return the weights of A at the locations identical to the sparsity pattern
    of B.

    If a non-zero entry in B does not exist in A, DGL returns 0 for that location
    instead.

    Note that the edge weights of the graph must be scalar, i.e. :attr:`A_weights`
    must be a 1D vector.

    In scipy notation this is identical to ``A[B != 0]``.

    Parameters
    ----------
    A : HeteroGraphIndex
        The input graph index as left operand.
    A_weights : Tensor
        The edge weights of graph A as 1D tensor.
    B : HeteroGraphIndex
        The input graph index as right operand.

    Returns
    -------
    B_weights : Tensor
        The output weights.
    """
    return F.from_dgl_nd(_CAPI_DGLCSRMask(A, F.to_dgl_nd(A_weights), B))
_init_api("dgl.sparse") _init_api("dgl.sparse")
@@ -57,6 +57,20 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
return ret;
}
template <typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
ret = impl::Full<XPU, DType>(val, length, ctx);
});
return ret;
}
template NDArray Full<int32_t>(int32_t val, int64_t length, DLContext ctx);
template NDArray Full<int64_t>(int64_t val, int64_t length, DLContext ctx);
template NDArray Full<float>(float val, int64_t length, DLContext ctx);
template NDArray Full<double>(double val, int64_t length, DLContext ctx);
IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64)
<< "Invalid ID type. Must be int32 or int64, but got int"
@@ -406,6 +420,25 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
return ret;
}
template <typename DType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {
NDArray ret;
CHECK_SAME_DTYPE(csr.indices, rows);
CHECK_SAME_DTYPE(csr.indices, cols);
CHECK_SAME_CONTEXT(csr.indices, rows);
CHECK_SAME_CONTEXT(csr.indices, cols);
CHECK_SAME_CONTEXT(csr.indices, weights);
ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetData", {
ret = impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols, weights, filler);
});
return ret;
}
template NDArray CSRGetData<float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);
template NDArray CSRGetData<double>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler);
std::vector<NDArray> CSRGetDataAndIndices(
CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK_SAME_DTYPE(csr.indices, rows);
......
@@ -102,8 +102,15 @@ runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler);
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, NullArray(rows->dtype), -1);
}
template <DLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> CSRGetDataAndIndices(
......
@@ -4,6 +4,7 @@
* \brief Array operator CPU implementation
*/
#include <dgl/array.h>
#include <dgl/runtime/ndarray.h>
#include <numeric>
#include "../arith.h"
@@ -173,16 +174,18 @@ template IdArray UnaryElewise<kDLCPU, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full /////////////////////////////
template <DLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
DType* ret_data = static_cast<DType*>(ret->data);
std::fill(ret_data, ret_data + length, val);
return ret;
}
template NDArray Full<kDLCPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template NDArray Full<kDLCPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
template NDArray Full<kDLCPU, float>(float val, int64_t length, DLContext ctx);
template NDArray Full<kDLCPU, double>(double val, int64_t length, DLContext ctx);
///////////////////////////// Range /////////////////////////////
......
/*!
* Copyright (c) 2021 by Contributors
* \file array/cpu/csr_get_data.cc
* \brief Retrieve entries of a CSR matrix
*/
#include <dgl/array.h>
#include <vector>
#include <unordered_set>
#include <numeric>
#include "array_utils.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
const IdType start, const IdType end, const IdType col,
std::vector<IdType> *ret_vec) {
const IdType *start_ptr = indices_data + start;
const IdType *end_ptr = indices_data + end;
auto it = std::lower_bound(start_ptr, end_ptr, col);
// This might be a multi-graph. We need to collect all of the matched
// columns.
for (; it != end_ptr; it++) {
// If the col exist
if (*it == col) {
IdType idx = it - indices_data;
ret_vec->push_back(data? data[idx] : idx);
} else {
// If we find a column that is different, we can stop searching now.
break;
}
}
}
template <DLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col id array.";
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
const IdType* row_data = static_cast<IdType*>(rows->data);
const IdType* col_data = static_cast<IdType*>(cols->data);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
const int64_t retlen = std::max(rowlen, collen);
bool return_eids = IsNullArray(weights);
const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();
if (return_eids)
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
"DType does not match row's dtype.";
NDArray ret = Full(filler, retlen, rows->ctx);
DType* ret_data = ret.Ptr<DType>();
// NOTE: In most cases, the input csr is already sorted. If not, we might need to
// consider sorting it especially when the number of (row, col) pairs is large.
// Need more benchmarks to justify the choice.
if (csr.sorted) {
// use binary search on each row
#pragma omp parallel for
for (int64_t p = 0; p < retlen; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
const IdType *start_ptr = indices_data + indptr_data[row_id];
const IdType *end_ptr = indices_data + indptr_data[row_id + 1];
auto it = std::lower_bound(start_ptr, end_ptr, col_id);
if (it != end_ptr && *it == col_id) {
const IdType idx = it - indices_data;
IdType eid = data ? data[idx] : idx;
ret_data[p] = return_eids ? eid : weight_data[eid];
}
}
} else {
// linear search on each row
#pragma omp parallel for
for (int64_t p = 0; p < retlen; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
for (IdType idx = indptr_data[row_id]; idx < indptr_data[row_id + 1]; ++idx) {
if (indices_data[idx] == col_id) {
IdType eid = data ? data[idx] : idx;
ret_data[p] = return_eids ? eid : weight_data[eid];
break;
}
}
}
}
return ret;
}
template NDArray CSRGetData<kDLCPU, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler);
template NDArray CSRGetData<kDLCPU, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLCPU, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int64_t filler);
} // namespace impl
} // namespace aten
} // namespace dgl
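The sorted-CSR branch above relies on std::lower_bound over each row's column slice and on stride-0 broadcasting when one of the index arrays has length 1. Below is a self-contained C++ sketch of just that lookup pattern, on invented data, without the DGL types, edge-ID indirection, or OpenMP.

// Binary search within a sorted CSR row, with a broadcast (length-1) row array.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // One sorted CSR row per indptr slice; columns inside a row are ascending.
  std::vector<int64_t> indptr  = {0, 3, 5};
  std::vector<int64_t> indices = {1, 4, 7, 0, 2};
  std::vector<double>  weights = {1.0, 2.0, 3.0, 4.0, 5.0};
  const double filler = 0.0;

  std::vector<int64_t> rows = {0};               // length 1: broadcast against cols
  std::vector<int64_t> cols = {4, 5, 7};
  const int64_t row_stride = 0, col_stride = 1;  // stride 0 repeats the single row

  for (int64_t p = 0; p < static_cast<int64_t>(cols.size()); ++p) {
    const int64_t r = rows[p * row_stride], c = cols[p * col_stride];
    auto begin = indices.begin() + indptr[r];
    auto end   = indices.begin() + indptr[r + 1];
    auto it = std::lower_bound(begin, end, c);   // binary search inside the row
    double out = (it != end && *it == c) ? weights[it - indices.begin()] : filler;
    std::cout << "(" << r << "," << c << ") -> " << out << "\n";  // 2, 0, 3
  }
  return 0;
}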
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/csr_mask.cc
* \brief CSR Masking Operation
*/
#include <dgl/array.h>
#include <parallel_hashmap/phmap.h>
#include <vector>
#include "array_utils.h"
namespace dgl {
using dgl::runtime::NDArray;
namespace aten {
namespace {
// TODO(BarclayII): avoid using map for sorted CSRs
template <typename IdType, typename DType>
void ComputeValues(
const IdType* A_indptr,
const IdType* A_indices,
const IdType* A_eids,
const DType* A_data,
const IdType* B_indptr,
const IdType* B_indices,
const IdType* B_eids,
DType* C_data,
int64_t M) {
phmap::flat_hash_map<IdType, DType> map;
#pragma omp parallel for firstprivate(map)
for (IdType i = 0; i < M; ++i) {
map.clear();
for (IdType u = A_indptr[i]; u < A_indptr[i + 1]; ++u) {
IdType kA = A_indices[u];
map[kA] = A_data[A_eids ? A_eids[u] : u];
}
for (IdType v = B_indptr[i]; v < B_indptr[i + 1]; ++v) {
IdType kB = B_indices[v];
auto it = map.find(kB);
C_data[B_eids ? B_eids[v] : v] = (it != map.end()) ? it->second : 0;
}
}
}
}; // namespace
template <int XPU, typename IdType, typename DType>
NDArray CSRMask(
const CSRMatrix& A,
NDArray A_weights,
const CSRMatrix& B) {
CHECK_EQ(A.num_rows, B.num_rows) << "Number of rows must match.";
CHECK_EQ(A.num_cols, B.num_cols) << "Number of columns must match.";
const bool A_has_eid = !IsNullArray(A.data);
const bool B_has_eid = !IsNullArray(B.data);
const IdType* A_indptr = A.indptr.Ptr<IdType>();
const IdType* A_indices = A.indices.Ptr<IdType>();
const IdType* A_eids = A_has_eid ? A.data.Ptr<IdType>() : nullptr;
const IdType* B_indptr = B.indptr.Ptr<IdType>();
const IdType* B_indices = B.indices.Ptr<IdType>();
const IdType* B_eids = B_has_eid ? B.data.Ptr<IdType>() : nullptr;
const DType* A_data = A_weights.Ptr<DType>();
const int64_t M = A.num_rows;
const int64_t N = A.num_cols;
NDArray C_weights = NDArray::Empty({B.indices->shape[0]}, A_weights->dtype, A_weights->ctx);
DType* C_data = C_weights.Ptr<DType>();
ComputeValues(A_indptr, A_indices, A_eids, A_data, B_indptr, B_indices, B_eids, C_data, M);
return C_weights;
}
template NDArray CSRMask<kDLCPU, int32_t, float>(const CSRMatrix&, NDArray, const CSRMatrix&);
template NDArray CSRMask<kDLCPU, int64_t, float>(const CSRMatrix&, NDArray, const CSRMatrix&);
template NDArray CSRMask<kDLCPU, int32_t, double>(const CSRMatrix&, NDArray, const CSRMatrix&);
template NDArray CSRMask<kDLCPU, int64_t, double>(const CSRMatrix&, NDArray, const CSRMatrix&);
}; // namespace aten
}; // namespace dgl
@@ -141,89 +141,6 @@ template NDArray CSRGetRowData<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLCPU, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetData /////////////////////////////
template <DLDeviceType XPU, typename IdType>
void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
const IdType start, const IdType end, const IdType col,
std::vector<IdType> *ret_vec) {
const IdType *start_ptr = indices_data + start;
const IdType *end_ptr = indices_data + end;
auto it = std::lower_bound(start_ptr, end_ptr, col);
// This might be a multi-graph. We need to collect all of the matched
// columns.
for (; it != end_ptr; it++) {
// If the col exist
if (*it == col) {
IdType idx = it - indices_data;
ret_vec->push_back(data? data[idx] : idx);
} else {
// If we find a column that is different, we can stop searching now.
break;
}
}
}
template <DLDeviceType XPU, typename IdType>
IdArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col id array.";
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
const IdType* row_data = static_cast<IdType*>(rows->data);
const IdType* col_data = static_cast<IdType*>(cols->data);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
const int64_t retlen = std::max(rowlen, collen);
IdArray ret = Full(-1, retlen, rows->dtype.bits, rows->ctx);
IdType* ret_data = ret.Ptr<IdType>();
// NOTE: In most cases, the input csr is already sorted. If not, we might need to
// consider sorting it especially when the number of (row, col) pairs is large.
// Need more benchmarks to justify the choice.
if (csr.sorted) {
// use binary search on each row
#pragma omp parallel for
for (int64_t p = 0; p < retlen; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
const IdType *start_ptr = indices_data + indptr_data[row_id];
const IdType *end_ptr = indices_data + indptr_data[row_id + 1];
auto it = std::lower_bound(start_ptr, end_ptr, col_id);
if (it != end_ptr && *it == col_id) {
const IdType idx = it - indices_data;
ret_data[p] = data? data[idx] : idx;
}
}
} else {
// linear search on each row
#pragma omp parallel for
for (int64_t p = 0; p < retlen; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
for (IdType idx = indptr_data[row_id]; idx < indptr_data[row_id + 1]; ++idx) {
if (indices_data[idx] == col_id) {
ret_data[p] = data? data[idx] : idx;
break;
}
}
}
}
return ret;
}
template NDArray CSRGetData<kDLCPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
template NDArray CSRGetData<kDLCPU, int64_t>(CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
......
@@ -197,9 +197,9 @@ template IdArray UnaryElewise<kDLGPU, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full /////////////////////////////
template <typename DType>
__global__ void _FullKernel(
DType* out, int64_t length, DType val) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
@@ -208,20 +208,22 @@ __global__ void _FullKernel(
}
}
template <DLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
DType* ret_data = static_cast<DType*>(ret->data);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
int nt = cuda::FindNumThreads(length);
int nb = (length + nt - 1) / nt;
CUDA_KERNEL_CALL((_FullKernel<DType>), nb, nt, 0, thr_entry->stream,
ret_data, length, val);
return ret;
}
template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx);
///////////////////////////// Range /////////////////////////////
......
/*!
* Copyright (c) 2021 by Contributors
* \file array/cuda/csr_get_data.cu
* \brief Retrieve entries of a CSR matrix
*/
#include <dgl/array.h>
#include <vector>
#include <unordered_set>
#include <numeric>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col id array.";
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
const int64_t rstlen = std::max(rowlen, collen);
IdArray rst = NDArray::Empty({rstlen}, weights->dtype, rows->ctx);
if (rstlen == 0)
return rst;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const int nt = cuda::FindNumThreads(rstlen);
const int nb = (rstlen + nt - 1) / nt;
bool return_eids = IsNullArray(weights);
if (return_eids)
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
"DType does not match row's dtype.";
// TODO(minjie): use binary search for sorted csr
CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
nb, nt, 0, thr_entry->stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
CSRHasData(csr)? csr.data.Ptr<IdType>() : nullptr,
rows.Ptr<IdType>(), cols.Ptr<IdType>(),
row_stride, col_stride, rstlen,
return_eids ? nullptr : weights.Ptr<DType>(), filler, rst.Ptr<DType>());
return rst;
}
template NDArray CSRGetData<kDLGPU, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler);
template NDArray CSRGetData<kDLGPU, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLGPU, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLGPU, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int64_t filler);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/csr_mm.cu
* \brief SpSpMM/SpGEMM C APIs and definitions.
*/
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include "./functor.cuh"
#include "./cusparse_dispatcher.cuh"
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using namespace dgl::runtime;
namespace aten {
namespace cusparse {
#if __CUDACC_VER_MAJOR__ == 11
/*! \brief Cusparse implementation of SpGEMM on Csr format for CUDA 11.0+ */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
const CSRMatrix& A,
const NDArray A_weights_array,
const CSRMatrix& B,
const NDArray B_weights_array) {
// We use Spgemm (SpSpMM) to perform following operation:
// C = A x B, where A, B and C are sparse matrices in csr format.
const int nnzA = A.indices->shape[0];
const int nnzB = B.indices->shape[0];
const DType alpha = 1.0;
const DType beta = 0.0;
auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
// device
auto ctx = A.indptr->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const DType* A_weights = A_weights_array.Ptr<DType>();
const DType* B_weights = B_weights_array.Ptr<DType>();
// allocate cusparse handle if needed
if (!thr_entry->cusparse_handle) {
CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
}
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
// all one data array
cusparseSpMatDescr_t matA, matB, matC;
IdArray dC_csrOffsets = IdArray::Empty({A.num_rows+1}, A.indptr->dtype, A.indptr->ctx);
IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
constexpr auto idtype = cusparse_idtype<IdType>::value;
constexpr auto dtype = cuda_dtype<DType>::value;
// Create sparse matrix A, B and C in CSR format
CUSPARSE_CALL(cusparseCreateCsr(&matA,
A.num_rows, A.num_cols, nnzA,
A.indptr.Ptr<DType>(),
A.indices.Ptr<DType>(),
const_cast<DType*>(A_weights), // cusparseCreateCsr only accepts non-const pointers
idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateCsr(&matB,
B.num_rows, B.num_cols, nnzB,
B.indptr.Ptr<DType>(),
B.indices.Ptr<DType>(),
const_cast<DType*>(B_weights), // cusparseCreateCsr only accepts non-const pointers
idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateCsr(&matC,
A.num_rows, B.num_cols, 0,
nullptr, nullptr, nullptr, idtype, idtype,
CUSPARSE_INDEX_BASE_ZERO, dtype));
// SpGEMM Computation
cusparseSpGEMMDescr_t spgemmDesc;
CUSPARSE_CALL(cusparseSpGEMM_createDescr(&spgemmDesc));
size_t workspace_size1 = 0, workspace_size2 = 0;
// ask bufferSize1 bytes for external memory
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC, dtype,
CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
NULL));
void* workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
// inspect the matrices A and B to understand the memory requiremnent
// for the next step
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC, dtype,
CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
workspace1));
// ask bufferSize2 bytes for external memory
CUSPARSE_CALL(cusparseSpGEMM_compute(thr_entry->cusparse_handle,
transA, transB, &alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
NULL));
void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
// compute the intermediate product of A * B
CUSPARSE_CALL(cusparseSpGEMM_compute(thr_entry->cusparse_handle,
transA, transB, &alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
workspace2));
// get matrix C non-zero entries C_nnz1
int64_t C_num_rows1, C_num_cols1, C_nnz1;
CUSPARSE_CALL(cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);
NDArray dC_weights = NDArray::Empty({C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
IdType* dC_columns_data = dC_columns.Ptr<IdType>();
DType* dC_weights_data = dC_weights.Ptr<DType>();
// update matC with the new pointers
CUSPARSE_CALL(cusparseCsrSetPointers(matC, dC_csrOffsets_data,
dC_columns_data, dC_weights_data));
// copy the final products to the matrix C
CUSPARSE_CALL(cusparseSpGEMM_copy(thr_entry->cusparse_handle,
transA, transB, &alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc));
device->FreeWorkspace(ctx, workspace1);
device->FreeWorkspace(ctx, workspace2);
// destroy matrix/vector descriptors
CUSPARSE_CALL(cusparseSpGEMM_destroyDescr(spgemmDesc));
CUSPARSE_CALL(cusparseDestroySpMat(matA));
CUSPARSE_CALL(cusparseDestroySpMat(matB));
CUSPARSE_CALL(cusparseDestroySpMat(matC));
return {CSRMatrix(A.num_rows, B.num_cols, dC_csrOffsets, dC_columns), dC_weights};
}
#else // __CUDACC_VER_MAJOR__ != 11
/*! \brief Cusparse implementation of SpGEMM on Csr format for older CUDA versions */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
const CSRMatrix& A,
const NDArray A_weights_array,
const CSRMatrix& B,
const NDArray B_weights_array) {
int nnzC;
csrgemm2Info_t info = nullptr;
size_t workspace_size;
const DType alpha = 1.;
const int nnzA = A.indices->shape[0];
const int nnzB = B.indices->shape[0];
const int m = A.num_rows;
const int n = A.num_cols;
const int k = B.num_cols;
auto ctx = A.indptr->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto idtype = A.indptr->dtype;
auto dtype = A_weights_array->dtype;
const DType* A_weights = A_weights_array.Ptr<DType>();
const DType* B_weights = B_weights_array.Ptr<DType>();
if (!thr_entry->cusparse_handle) {
CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
}
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
CUSPARSE_CALL(cusparseSetPointerMode(
thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST));
CUSPARSE_CALL(cusparseCreateCsrgemm2Info(&info));
cusparseMatDescr_t matA, matB, matC, matD;
CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
CUSPARSE_CALL(cusparseCreateMatDescr(&matC));
CUSPARSE_CALL(cusparseCreateMatDescr(&matD)); // needed even if D is null
CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(thr_entry->cusparse_handle,
m, n, k, &alpha,
matA, nnzA, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
matB, nnzB, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
nullptr,
matD, 0, nullptr, nullptr,
info,
&workspace_size));
void *workspace = device->AllocWorkspace(ctx, workspace_size);
IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);
CUSPARSE_CALL(CSRGEMM<DType>::nnz(thr_entry->cusparse_handle,
m, n, k,
matA, nnzA, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
matB, nnzB, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
matD, 0, nullptr, nullptr,
matC, C_indptr.Ptr<IdType>(), &nnzC, info, workspace));
IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);
NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);
CUSPARSE_CALL(CSRGEMM<DType>::compute(thr_entry->cusparse_handle,
m, n, k, &alpha,
matA, nnzA, A_weights, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
matB, nnzB, B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
nullptr,
matD, 0, nullptr, nullptr, nullptr,
matC, C_weights.Ptr<DType>(), C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(),
info, workspace));
device->FreeWorkspace(ctx, workspace);
CUSPARSE_CALL(cusparseDestroyCsrgemm2Info(info));
CUSPARSE_CALL(cusparseDestroyMatDescr(matA));
CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
CUSPARSE_CALL(cusparseDestroyMatDescr(matD));
return {CSRMatrix(m, k, C_indptr, C_indices), C_weights};
}
#endif // __CUDACC_VER_MAJOR__ == 11
} // namespace cusparse
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
const CSRMatrix& A,
NDArray A_weights,
const CSRMatrix& B,
NDArray B_weights) {
auto ctx = A.indptr->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
CSRMatrix newA, newB;
bool cast = false;
// Cast 64 bit indices to 32 bit.
if (A.indptr->dtype.bits == 64) {
newA = CSRMatrix(
A.num_rows, A.num_cols,
AsNumBits(A.indptr, 32), AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
newB = CSRMatrix(
B.num_rows, B.num_cols,
AsNumBits(B.indptr, 32), AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
cast = true;
}
// Reorder weights if A or B has edge IDs
NDArray newA_weights, newB_weights;
if (CSRHasData(A))
newA_weights = IndexSelect(A_weights, A.data);
if (CSRHasData(B))
newB_weights = IndexSelect(B_weights, B.data);
auto result = cusparse::CusparseSpgemm<DType, int32_t>(
cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,
cast ? newB : B, CSRHasData(B) ? newB_weights : B_weights);
// Cast 32 bit indices back to 64 bit if necessary
if (cast) {
CSRMatrix C = result.first;
return {
CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64)),
result.second};
} else {
return result;
}
}
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
} // namespace aten
} // namespace dgl
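For reference, the result the cuSPARSE SpGEMM path is asked to produce is the ordinary sparse matrix product C = A x B with multiplied-and-accumulated scalar weights. The short CPU-only C++ program below computes the same thing row by row with a map-based accumulator; the Csr struct and the toy matrices are invented for illustration and carry none of the cuSPARSE workspace or descriptor handling.

// CPU reference for weighted CSR SpGEMM: accumulate each output row in a map.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

struct Csr {
  int64_t rows, cols;
  std::vector<int64_t> indptr, indices;
  std::vector<double> weights;
};

Csr SpGEMM(const Csr& A, const Csr& B) {
  Csr C{A.rows, B.cols, {0}, {}, {}};
  for (int64_t i = 0; i < A.rows; ++i) {
    std::map<int64_t, double> acc;  // column -> accumulated value of row i of C
    for (int64_t u = A.indptr[i]; u < A.indptr[i + 1]; ++u)
      for (int64_t v = B.indptr[A.indices[u]]; v < B.indptr[A.indices[u] + 1]; ++v)
        acc[B.indices[v]] += A.weights[u] * B.weights[v];
    for (const auto& kv : acc) {
      C.indices.push_back(kv.first);
      C.weights.push_back(kv.second);
    }
    C.indptr.push_back(static_cast<int64_t>(C.indices.size()));
  }
  return C;
}

int main() {
  // A is 2x2 with nonzeros (0,0)=1, (0,1)=2, (1,1)=3; B is the 2x2 identity.
  Csr A{2, 2, {0, 2, 3}, {0, 1, 1}, {1.0, 2.0, 3.0}};
  Csr B{2, 2, {0, 1, 2}, {0, 1}, {1.0, 1.0}};
  Csr C = SpGEMM(A, B);
  for (size_t k = 0; k < C.indices.size(); ++k)
    std::cout << C.indices[k] << ":" << C.weights[k] << " ";  // 0:1 1:2 1:3
  std::cout << "\n";
  return 0;
}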
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/spmm.cu
* \brief SpGEAM C APIs and definitions.
*/
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include "./functor.cuh"
#include "./cusparse_dispatcher.cuh"
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
using namespace dgl::runtime;
namespace aten {
namespace cusparse {
/*! Cusparse implementation of SpSum on Csr format. */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
const CSRMatrix& A,
const NDArray A_weights_array,
const CSRMatrix& B,
const NDArray B_weights_array) {
const int m = A.num_rows;
const int n = A.num_cols;
const int nnzA = A.indices->shape[0];
const int nnzB = B.indices->shape[0];
int nnzC;
const DType alpha = 1.0;
const DType beta = 1.0;
auto ctx = A.indptr->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const DType* A_weights = A_weights_array.Ptr<DType>();
const DType* B_weights = B_weights_array.Ptr<DType>();
// allocate cusparse handle if needed
if (!thr_entry->cusparse_handle)
CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
cusparseMatDescr_t matA, matB, matC;
CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
CUSPARSE_CALL(cusparseCreateMatDescr(&matC));
cusparseSetPointerMode(thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST);
size_t workspace_size = 0;
/* prepare output C */
IdArray dC_csrOffsets = IdArray::Empty({A.num_rows+1}, A.indptr->dtype, ctx);
IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
IdArray dC_columns;
NDArray dC_weights;
IdType* dC_columns_data = dC_columns.Ptr<IdType>();
DType* dC_weights_data = dC_weights.Ptr<DType>();
/* prepare buffer */
CUSPARSE_CALL(CSRGEAM<DType>::bufferSizeExt(
thr_entry->cusparse_handle, m, n, &alpha,
matA, nnzA, A_weights,
A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(),
&beta, matB, nnzB, B_weights,
B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(),
matC, dC_weights_data, dC_csrOffsets_data, dC_columns_data,
&workspace_size));
void *workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(CSRGEAM<DType>::nnz(thr_entry->cusparse_handle,
m, n, matA, nnzA,
A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(),
matB, nnzB,
B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(),
matC, dC_csrOffsets_data, &nnzC, workspace));
dC_columns = IdArray::Empty({nnzC}, A.indptr->dtype, ctx);
dC_weights = NDArray::Empty({nnzC}, A_weights_array->dtype, ctx);
dC_columns_data = dC_columns.Ptr<IdType>();
dC_weights_data = dC_weights.Ptr<DType>();
CUSPARSE_CALL(CSRGEAM<DType>::compute(
thr_entry->cusparse_handle, m, n, &alpha,
matA, nnzA, A_weights,
A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(),
&beta, matB, nnzB, B_weights,
B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(),
matC, dC_weights_data, dC_csrOffsets_data, dC_columns_data,
workspace));
device->FreeWorkspace(ctx, workspace);
// destroy matrix/vector descriptors
CUSPARSE_CALL(cusparseDestroyMatDescr(matA));
CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
return {CSRMatrix(A.num_rows, A.num_cols, dC_csrOffsets, dC_columns),
dC_weights};
}
} // namespace cusparse
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& As,
const std::vector<NDArray>& A_weights) {
const int64_t M = As[0].num_rows;
const int64_t N = As[0].num_cols;
const int64_t n = As.size();
// Cast 64 bit indices to 32 bit
std::vector<CSRMatrix> newAs;
bool cast = false;
if (As[0].indptr->dtype.bits == 64) {
newAs.reserve(n);
for (int i = 0; i < n; ++i)
newAs.emplace_back(
As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32),
AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32));
cast = true;
}
const std::vector<CSRMatrix> &As_ref = cast ? newAs : As;
// Reorder weights if A[i] has edge IDs
std::vector<NDArray> A_weights_reordered(n);
for (int i = 0; i < n; ++i) {
if (CSRHasData(As[i]))
A_weights_reordered[i] = IndexSelect(A_weights[i], As[i].data);
else
A_weights_reordered[i] = A_weights[i];
}
// Loop and sum
auto result = std::make_pair(
CSRMatrix(
As_ref[0].num_rows, As_ref[0].num_cols,
As_ref[0].indptr, As_ref[0].indices),
A_weights_reordered[0]); // Weights already reordered so we don't need As[0].data
for (int64_t i = 1; i < n; ++i)
result = cusparse::CusparseCsrgeam2<DType, int32_t>(
result.first, result.second, As_ref[i], A_weights_reordered[i]);
// Cast 32 bit indices back to 64 bit if necessary
if (cast) {
CSRMatrix C = result.first;
return {
CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64)),
result.second};
} else {
return result;
}
}
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
} // namespace aten
} // namespace dgl
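Each cusparseXcsrgeam2 call above adds two weighted CSR matrices of the same shape; CSRSum then folds a whole list of graphs through that pairwise addition. A plain C++ reference of the pairwise step, on invented toy matrices and with a per-row map in place of the cuSPARSE workspace handling, looks like this.

// CPU reference for one csrgeam2-style step: C = A + B for same-shape CSR matrices.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Two 2x3 matrices given as (indptr, indices, weights).
  std::vector<int64_t> A_indptr = {0, 2, 3}, A_indices = {0, 2, 1};
  std::vector<double>  A_weights = {1.0, 2.0, 3.0};
  std::vector<int64_t> B_indptr = {0, 1, 3}, B_indices = {2, 0, 1};
  std::vector<double>  B_weights = {10.0, 20.0, 30.0};

  std::vector<int64_t> C_indptr = {0}, C_indices;
  std::vector<double>  C_weights;
  for (int64_t r = 0; r + 1 < static_cast<int64_t>(A_indptr.size()); ++r) {
    std::map<int64_t, double> acc;  // union of the two rows, summed per column
    for (int64_t i = A_indptr[r]; i < A_indptr[r + 1]; ++i) acc[A_indices[i]] += A_weights[i];
    for (int64_t i = B_indptr[r]; i < B_indptr[r + 1]; ++i) acc[B_indices[i]] += B_weights[i];
    for (const auto& kv : acc) { C_indices.push_back(kv.first); C_weights.push_back(kv.second); }
    C_indptr.push_back(static_cast<int64_t>(C_indices.size()));
  }
  for (size_t k = 0; k < C_indices.size(); ++k)
    std::cout << C_indices[k] << ":" << C_weights[k] << " ";  // 0:1 2:12 0:20 1:33
  std::cout << "\n";
  return 0;
}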
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/dispatcher.cuh
* \brief Templates to dispatch into different cuSPARSE routines based on the type
* argument.
*/
#ifndef DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_
#define DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_
#include <cusparse.h>
#include <dgl/runtime/c_runtime_api.h>
namespace dgl {
namespace aten {
/*! \brief cusparseXcsrgemm dispatcher */
template <typename DType>
struct CSRGEMM {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
BUG_IF_FAIL(false) << "This piece of code should not be reached.";
return 0;
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgemm2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
BUG_IF_FAIL(false) << "This piece of code should not be reached.";
return 0;
}
};
template <>
struct CSRGEMM<float> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
return cusparseScsrgemm2_bufferSizeExt(args...);
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgemm2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
return cusparseScsrgemm2(args...);
}
};
template <>
struct CSRGEMM<double> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
return cusparseDcsrgemm2_bufferSizeExt(args...);
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgemm2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
return cusparseDcsrgemm2(args...);
}
};
/*! \brief cusparseXcsrgeam dispatcher */
template <typename DType>
struct CSRGEAM {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
BUG_IF_FAIL(false) << "This piece of code should not be reached.";
return 0;
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgeam2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
BUG_IF_FAIL(false) << "This piece of code should not be reached.";
return 0;
}
};
template <>
struct CSRGEAM<float> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
return cusparseScsrgeam2_bufferSizeExt(args...);
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgeam2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
return cusparseScsrgeam2(args...);
}
};
template <>
struct CSRGEAM<double> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
return cusparseDcsrgeam2_bufferSizeExt(args...);
}
template <typename... Args>
static inline cusparseStatus_t nnz(Args... args) {
return cusparseXcsrgeam2Nnz(args...);
}
template <typename... Args>
static inline cusparseStatus_t compute(Args... args) {
return cusparseDcsrgeam2(args...);
}
};
}; // namespace aten
}; // namespace dgl
#endif // DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_
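The dispatcher header uses a common template trick: a primary template whose members fail loudly for unsupported value types, plus full specializations that forward variadic arguments to the concrete single- and double-precision routines. A tiny standalone sketch of the same pattern follows; the routine names (saxpy_like, daxpy_like) and the Axpy struct are invented and merely stand in for the cusparseS*/cusparseD* pairs.

// Type-based dispatch via a primary template (error path) and specializations.
#include <cassert>
#include <iostream>

int saxpy_like(float a, float x)   { std::cout << "float path: "  << a * x << "\n"; return 0; }
int daxpy_like(double a, double x) { std::cout << "double path: " << a * x << "\n"; return 0; }

template <typename DType>
struct Axpy {
  template <typename... Args>
  static int compute(Args...) {
    assert(false && "unsupported data type");  // plays the role of BUG_IF_FAIL(false)
    return -1;
  }
};

template <>
struct Axpy<float> {
  template <typename... Args>
  static int compute(Args... args) { return saxpy_like(args...); }  // forwards to the float routine
};

template <>
struct Axpy<double> {
  template <typename... Args>
  static int compute(Args... args) { return daxpy_like(args...); }  // forwards to the double routine
};

int main() {
  Axpy<float>::compute(1.0f, 2.0f);   // dispatches to saxpy_like
  Axpy<double>::compute(1.0, 2.0);    // dispatches to daxpy_like
  return 0;
}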
@@ -17,37 +17,6 @@ using runtime::NDArray;
namespace aten {
namespace impl {
/*!
* \brief Search adjacency list linearly for each (row, col) pair and
* write the data under the matched position in the indices array to the output.
*
* If there is no match, -1 is written.
* If there are multiple matches, only the first match is written.
* If the given data array is null, write the matched position to the output.
*/
template <typename IdType>
__global__ void _LinearSearchKernel(
const IdType* indptr, const IdType* indices, const IdType* data,
const IdType* row, const IdType* col,
int64_t row_stride, int64_t col_stride,
int64_t length, IdType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
int rpos = tx * row_stride, cpos = tx * col_stride;
IdType v = -1;
const IdType r = row[rpos], c = col[cpos];
for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
if (indices[i] == c) {
v = (data)? data[i] : i;
break;
}
}
out[tx] = v;
tx += stride_x;
}
}
///////////////////////////// CSRIsNonZero /////////////////////////////
template <DLDeviceType XPU, typename IdType>
@@ -61,12 +30,12 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
IdArray out = aten::NewIdArray(1, ctx, sizeof(IdType) * 8);
const IdType* data = nullptr;
// TODO(minjie): use binary search for sorted csr
CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
1, 1, 0, thr_entry->stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
rows.Ptr<IdType>(), cols.Ptr<IdType>(),
1, 1, 1,
static_cast<IdType*>(nullptr), static_cast<IdType>(-1), out.Ptr<IdType>());
out = out.CopyTo(DLContext{kDLCPU, 0});
return *out.Ptr<IdType>() != -1;
}
@@ -89,12 +58,12 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
const int nb = (rstlen + nt - 1) / nt;
const IdType* data = nullptr;
// TODO(minjie): use binary search for sorted csr
CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
nb, nt, 0, thr_entry->stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
row.Ptr<IdType>(), col.Ptr<IdType>(),
row_stride, col_stride, rstlen,
static_cast<IdType*>(nullptr), static_cast<IdType>(-1), rst.Ptr<IdType>());
return rst != -1;
}
@@ -305,41 +274,6 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix , NDArray);
///////////////////////////// CSRGetData /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray CSRGetData(CSRMatrix csr, NDArray row, NDArray col) {
const int64_t rowlen = row->shape[0];
const int64_t collen = col->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col id array.";
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
const int64_t rstlen = std::max(rowlen, collen);
IdArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
if (rstlen == 0)
return rst;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const int nt = cuda::FindNumThreads(rstlen);
const int nb = (rstlen + nt - 1) / nt;
// TODO(minjie): use binary search for sorted csr
CUDA_KERNEL_CALL(_LinearSearchKernel,
nb, nt, 0, thr_entry->stream,
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
CSRHasData(csr)? csr.data.Ptr<IdType>() : nullptr,
row.Ptr<IdType>(), col.Ptr<IdType>(),
row_stride, col_stride, rstlen,
rst.Ptr<IdType>());
return rst;
}
template NDArray CSRGetData<kDLGPU, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
template NDArray CSRGetData<kDLGPU, int64_t>(CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
/*!
......
@@ -138,6 +138,40 @@ void _Fill(DType* ptr, size_t length, DType val) {
CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, thr_entry->stream, ptr, length, val);
}
/*!
* \brief Search adjacency list linearly for each (row, col) pair and
* write the data under the matched position in the indices array to the output.
*
* If there is no match, the value in \c filler is written.
* If there are multiple matches, only the first match is written.
* If the given data array is null, write the matched position to the output.
*/
template <typename IdType, typename DType>
__global__ void _LinearSearchKernel(
const IdType* indptr, const IdType* indices, const IdType* data,
const IdType* row, const IdType* col,
int64_t row_stride, int64_t col_stride,
int64_t length, const DType* weights, DType filler, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
int rpos = tx * row_stride, cpos = tx * col_stride;
IdType v = -1;
const IdType r = row[rpos], c = col[cpos];
for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
if (indices[i] == c) {
v = data ? data[i] : i;
break;
}
}
if (v == -1)
out[tx] = filler;
else
out[tx] = weights ? weights[v] : v;
tx += stride_x;
}
}
} // namespace cuda } // namespace cuda
} // namespace dgl } // namespace dgl
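For orientation, a host-side launch of the kernel above follows the same pattern as the CSRGetData call earlier in this diff. The sketch below is illustrative only: the wrapper name LaunchLinearSearch is hypothetical, and it assumes the DGL helpers used elsewhere in this commit (CUDA_KERNEL_CALL, cuda::FindNumThreads, runtime::CUDAThreadEntry, CSRHasData, NDArray::Ptr) behave as shown there.

// Hypothetical host-side sketch (not part of the commit): look up the weight of
// each (rows[i], cols[i]) pair in `csr`, writing `filler` where the pair is absent.
template <typename IdType, typename DType>
void LaunchLinearSearch(const CSRMatrix& csr, NDArray rows, NDArray cols,
                        NDArray weights, DType filler, NDArray out) {
  const int64_t rstlen = out->shape[0];
  if (rstlen == 0) return;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  const int nt = cuda::FindNumThreads(rstlen);
  const int nb = (rstlen + nt - 1) / nt;
  CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
      nb, nt, 0, thr_entry->stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr,  // null data => use position
      rows.Ptr<IdType>(), cols.Ptr<IdType>(),
      /*row_stride=*/1, /*col_stride=*/1, rstlen,
      weights.Ptr<DType>(), filler, out.Ptr<DType>());
}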
......
...@@ -134,6 +134,9 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -134,6 +134,9 @@ std::pair<CSRMatrix, NDArray> CSRMM(
NDArray A_weights, NDArray A_weights,
CSRMatrix B, CSRMatrix B,
NDArray B_weights) { NDArray B_weights) {
CHECK_EQ(A.num_cols, B.num_rows) <<
"The number of nodes of destination node type of the first graph must be the "
"same as the number of nodes of source node type of the second graph.";
CheckCtx( CheckCtx(
A.indptr->ctx, A.indptr->ctx,
{A_weights, B_weights}, {A_weights, B_weights},
...@@ -143,8 +146,7 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -143,8 +146,7 @@ std::pair<CSRMatrix, NDArray> CSRMM(
CHECK_EQ(A_weights->dtype, B_weights->dtype) << "Data types of two edge weights must match."; CHECK_EQ(A_weights->dtype, B_weights->dtype) << "Data types of two edge weights must match.";
std::pair<CSRMatrix, NDArray> ret; std::pair<CSRMatrix, NDArray> ret;
// TODO(BarclayII): change to ATEN_XPU_SWITCH_CUDA once the GPU kernels are implemented ATEN_XPU_SWITCH_CUDA(A.indptr->ctx.device_type, XPU, "CSRMM", {
ATEN_XPU_SWITCH(A.indptr->ctx.device_type, XPU, "CSRMM", {
ATEN_ID_TYPE_SWITCH(A.indptr->dtype, IdType, { ATEN_ID_TYPE_SWITCH(A.indptr->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", { ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", {
ret = CSRMM<XPU, IdType, DType>(A, A_weights, B, B_weights); ret = CSRMM<XPU, IdType, DType>(A, A_weights, B, B_weights);
...@@ -160,9 +162,11 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -160,9 +162,11 @@ std::pair<CSRMatrix, NDArray> CSRSum(
CHECK(A.size() > 0) << "The list of graphs must not be empty."; CHECK(A.size() > 0) << "The list of graphs must not be empty.";
CHECK_EQ(A.size(), A_weights.size()) << CHECK_EQ(A.size(), A_weights.size()) <<
"The list of edge weights must have the same length as the list of graphs."; "The list of edge weights must have the same length as the list of graphs.";
auto ctx = A[0].indptr->ctx; const auto ctx = A[0].indptr->ctx;
auto idtype = A[0].indptr->dtype; const auto idtype = A[0].indptr->dtype;
auto dtype = A_weights[0]->dtype; const auto dtype = A_weights[0]->dtype;
const auto num_rows = A[0].num_rows;
const auto num_cols = A[0].num_cols;
for (size_t i = 0; i < A.size(); ++i) { for (size_t i = 0; i < A.size(); ++i) {
CHECK_EQ(A[i].indptr->ctx, ctx) << "The devices of all graphs must be equal."; CHECK_EQ(A[i].indptr->ctx, ctx) << "The devices of all graphs must be equal.";
CHECK_EQ(A[i].indptr->dtype, idtype) << "The ID types of all graphs must be equal."; CHECK_EQ(A[i].indptr->dtype, idtype) << "The ID types of all graphs must be equal.";
...@@ -172,11 +176,12 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -172,11 +176,12 @@ std::pair<CSRMatrix, NDArray> CSRSum(
"The devices of edge weights must be the same as that of the graphs."; "The devices of edge weights must be the same as that of the graphs.";
CHECK_EQ(A_weights[i]->dtype, dtype) << CHECK_EQ(A_weights[i]->dtype, dtype) <<
"The data types of all edge weights must be equal."; "The data types of all edge weights must be equal.";
CHECK_EQ(A[i].num_rows, num_rows) << "Graphs must have the same number of nodes.";
CHECK_EQ(A[i].num_cols, num_cols) << "Graphs must have the same number of nodes.";
} }
std::pair<CSRMatrix, NDArray> ret; std::pair<CSRMatrix, NDArray> ret;
// TODO(BarclayII): change to ATEN_XPU_SWITCH_CUDA once the GPU kernels are implemented ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "CSRSum", {
ATEN_XPU_SWITCH(ctx.device_type, XPU, "CSRSum", {
ATEN_ID_TYPE_SWITCH(idtype, IdType, { ATEN_ID_TYPE_SWITCH(idtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH(dtype, DType, "Edge weights", { ATEN_FLOAT_TYPE_SWITCH(dtype, DType, "Edge weights", {
ret = CSRSum<XPU, IdType, DType>(A, A_weights); ret = CSRSum<XPU, IdType, DType>(A, A_weights);
...@@ -186,29 +191,6 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -186,29 +191,6 @@ std::pair<CSRMatrix, NDArray> CSRSum(
return ret; return ret;
} }
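For readers unfamiliar with the semantics behind the shape checks above: CSRSum adds the weights of entries that share a (row, col) coordinate across input matrices of identical shape. The CPU sketch below is a hypothetical reference only; SimpleCSR and SumCSR are illustrative stand-ins, not DGL types, and not the GPU path this commit adds.

// Hypothetical CPU reference for CSRSum semantics: entries with the same
// (row, col) coordinate across all inputs have their weights added together.
#include <cstdint>
#include <map>
#include <vector>

struct SimpleCSR {                       // simplified stand-in for a CSR matrix
  int64_t num_rows, num_cols;
  std::vector<int64_t> indptr, indices;  // indptr has num_rows + 1 entries
  std::vector<float> weights;            // one weight per stored entry
};

SimpleCSR SumCSR(const std::vector<SimpleCSR>& mats) {
  // Assumes mats is non-empty and all matrices have the same shape,
  // mirroring the checks in CSRSum above.
  SimpleCSR out{mats[0].num_rows, mats[0].num_cols, {0}, {}, {}};
  for (int64_t r = 0; r < out.num_rows; ++r) {
    std::map<int64_t, float> acc;        // column -> accumulated weight for row r
    for (const auto& m : mats)
      for (int64_t i = m.indptr[r]; i < m.indptr[r + 1]; ++i)
        acc[m.indices[i]] += m.weights[i];
    for (const auto& kv : acc) {
      out.indices.push_back(kv.first);
      out.weights.push_back(kv.second);
    }
    out.indptr.push_back(static_cast<int64_t>(out.indices.size()));
  }
  return out;
}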
NDArray CSRMask(const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B) {
CHECK_EQ(A.indptr->ctx, A_weights->ctx) <<
"Device of the graph and the edge weights must match.";
CHECK_EQ(A.indptr->ctx, B.indptr->ctx) << "Device of two graphs must match.";
CHECK_EQ(A.indptr->dtype, B.indptr->dtype) << "ID types of two graphs must match.";
CHECK_EQ(A_weights->shape[0], A.indices->shape[0]) <<
"Shape of edge weights does not match the number of edges.";
auto ctx = A.indptr->ctx;
auto idtype = A.indptr->dtype;
auto dtype = A_weights->dtype;
NDArray ret;
// TODO(BarclayII): change to ATEN_XPU_SWITCH_CUDA once the GPU kernels are implemented
ATEN_XPU_SWITCH(ctx.device_type, XPU, "CSRMask", {
ATEN_ID_TYPE_SWITCH(idtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH(dtype, DType, "Edge weights", {
ret = CSRMask<XPU, IdType, DType>(A, A_weights, B);
});
});
});
return ret;
}
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
...@@ -296,61 +278,83 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping") ...@@ -296,61 +278,83 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping")
*rv = GetEdgeMapping(graph); *rv = GetEdgeMapping(graph);
}); });
/*!
* \brief Sparse matrix multiplication with graph interface.
*
* \param A_ref The left operand.
* \param A_weights The edge weights of graph A.
* \param B_ref The right operand.
* \param B_weights The edge weights of graph B.
* \param num_vtypes The number of vertex types of the graph to be returned.
* \return A pair consisting of the new graph as well as its edge weights.
*/
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMM") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int M = args[0]; const HeteroGraphRef A_ref = args[0];
int N = args[1]; NDArray A_weights = args[1];
int P = args[2]; const HeteroGraphRef B_ref = args[2];
NDArray A_indptr = args[3]; NDArray B_weights = args[3];
NDArray A_indices = args[4]; int num_vtypes = args[4];
NDArray A_data = args[5];
NDArray B_indptr = args[6]; const HeteroGraphPtr A = A_ref.sptr();
NDArray B_indices = args[7]; const HeteroGraphPtr B = B_ref.sptr();
NDArray B_data = args[8]; CHECK_EQ(A->NumEdgeTypes(), 1) << "The first graph must have only one edge type.";
auto result = CSRMM( CHECK_EQ(B->NumEdgeTypes(), 1) << "The second graph must have only one edge type.";
CSRMatrix(M, N, A_indptr, A_indices), const auto A_csr = A->GetCSRMatrix(0);
A_data, const auto B_csr = B->GetCSRMatrix(0);
CSRMatrix(N, P, B_indptr, B_indices), auto result = CSRMM(A_csr, A_weights, B_csr, B_weights);
B_data);
List<Value> ret; List<ObjectRef> ret;
ret.push_back(Value(MakeValue(result.first.indptr))); ret.push_back(HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.first.indices)));
ret.push_back(Value(MakeValue(result.second))); ret.push_back(Value(MakeValue(result.second)));
*rv = ret; *rv = ret;
}); });
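Similarly, the A.num_cols == B.num_rows requirement checked in CSRMM corresponds to the inner dimension of a sparse matrix product. Below is a hypothetical CPU reference for those semantics, reusing the SimpleCSR stand-in from the CSRSum sketch above; it is not the cuSPARSE-backed implementation this commit dispatches to.

// Hypothetical CPU reference for CSRMM semantics: C = A * B with
// C(r, c) = sum over k of A(r, k) * B(k, c), computed row by row.
SimpleCSR MulCSR(const SimpleCSR& A, const SimpleCSR& B) {
  // Assumes A.num_cols == B.num_rows, mirroring the check in CSRMM above.
  SimpleCSR C{A.num_rows, B.num_cols, {0}, {}, {}};
  for (int64_t r = 0; r < A.num_rows; ++r) {
    std::map<int64_t, float> acc;                // column of C -> accumulated value
    for (int64_t i = A.indptr[r]; i < A.indptr[r + 1]; ++i) {
      const int64_t k = A.indices[i];            // shared (inner) dimension index
      for (int64_t j = B.indptr[k]; j < B.indptr[k + 1]; ++j)
        acc[B.indices[j]] += A.weights[i] * B.weights[j];
    }
    for (const auto& kv : acc) {
      C.indices.push_back(kv.first);
      C.weights.push_back(kv.second);
    }
    C.indptr.push_back(static_cast<int64_t>(C.indices.size()));
  }
  return C;
}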
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRSum") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRSum")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int M = args[0]; List<HeteroGraphRef> A_refs = args[0];
int N = args[1]; List<Value> A_weights = args[1];
List<Value> A_indptr = args[2];
List<Value> A_indices = args[3]; std::vector<NDArray> weights = ListValueToVector<NDArray>(A_weights);
List<Value> A_data = args[4]; std::vector<CSRMatrix> mats;
std::vector<NDArray> weights = ListValueToVector<NDArray>(A_data); mats.reserve(A_refs.size());
std::vector<CSRMatrix> mats(A_indptr.size()); int num_vtypes = 0;
for (int i = 0; i < A_indptr.size(); ++i) for (auto A_ref : A_refs) {
mats[i] = CSRMatrix(M, N, A_indptr[i]->data, A_indices[i]->data); const HeteroGraphPtr A = A_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1) << "Graphs must have only one edge type.";
mats.push_back(A->GetCSRMatrix(0));
if (num_vtypes == 0)
num_vtypes = A->NumVertexTypes();
}
auto result = CSRSum(mats, weights); auto result = CSRSum(mats, weights);
List<Value> ret;
ret.push_back(Value(MakeValue(result.first.indptr))); List<ObjectRef> ret;
ret.push_back(Value(MakeValue(result.first.indices))); ret.push_back(HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.second))); ret.push_back(Value(MakeValue(result.second)));
*rv = ret; *rv = ret;
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMask") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMask")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int M = args[0]; const HeteroGraphRef A_ref = args[0];
int N = args[1]; NDArray A_weights = args[1];
NDArray A_indptr = args[2]; const HeteroGraphRef B_ref = args[2];
NDArray A_indices = args[3];
NDArray A_data = args[4]; const HeteroGraphPtr A = A_ref.sptr();
NDArray B_indptr = args[5]; const HeteroGraphPtr B = B_ref.sptr();
NDArray B_indices = args[6]; CHECK_EQ(A->NumEdgeTypes(), 1) << "Both graphs must have only one edge type.";
auto result = CSRMask( CHECK_EQ(B->NumEdgeTypes(), 1) << "Both graphs must have only one edge type.";
CSRMatrix(M, N, A_indptr, A_indices), const CSRMatrix& A_csr = A->GetCSRMatrix(0);
A_data, const COOMatrix& B_coo = B->GetCOOMatrix(0);
CSRMatrix(M, N, B_indptr, B_indices)); CHECK_EQ(A_csr.num_rows, B_coo.num_rows) <<
"Both graphs must have the same number of nodes.";
CHECK_EQ(A_csr.num_cols, B_coo.num_cols) <<
"Both graphs must have the same number of nodes.";
NDArray result;
ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", {
result = aten::CSRGetData<DType>(A_csr, B_coo.row, B_coo.col, A_weights, 0.);
});
*rv = result; *rv = result;
}); });
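The registration above implements masking by looking up each nonzero coordinate of B in A via CSRGetData with a filler of 0. A hypothetical CPU reference for the same semantics follows, again reusing the SimpleCSR stand-in from the CSRSum sketch; B is represented by its COO row/col arrays, as in the registration body.

// Hypothetical CPU reference for CSRMask semantics: for every nonzero
// coordinate (row[i], col[i]) of B, emit A's weight at that coordinate,
// or the filler 0 when A has no such entry.
std::vector<float> MaskCSR(const SimpleCSR& A,
                           const std::vector<int64_t>& row,
                           const std::vector<int64_t>& col) {
  // Assumes row/col indices are within A's shape, mirroring the checks above.
  std::vector<float> out(row.size(), 0.f);       // filler = 0
  for (size_t i = 0; i < row.size(); ++i) {
    for (int64_t j = A.indptr[row[i]]; j < A.indptr[row[i] + 1]; ++j) {
      if (A.indices[j] == col[i]) {              // first match wins
        out[i] = A.weights[j];
        break;
      }
    }
  }
  return out;
}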
......