Unverified commit 1113f674 authored by Israt Nisa, committed by GitHub

[Kernel] Add heterograph support in CUDA kernels (SpMM, SDDMM) (#2925)



* Added heterograph support SpMM, SDDMM

* bug fix cuda stream

* add cudaStrm destroy and fix whitespace

* Added heterograph support SpMM, SDDMM

* bug fix cuda stream

* add cudaStrm destroy and fix whitespace

* changed max stream = 1

* Fixed ctx

* using default stream

* Added heterograph support SpMM, SDDMM

* bug fix cuda stream

* add cudaStrm destroy and fix whitespace

* changed max stream = 1

* Fixed ctx

* using default stream

* fix bug in copy_rhs

* changed by mistake

* minor datatype change

* added datatype check
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
parent ff519f98
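For context, here is a minimal, illustrative sketch of the kind of heterograph message passing these kernels serve, written against DGL's Python frontend. The toy graph, feature names, and the assumption that this call path ultimately reaches the new SpMMCsrHetero/SDDMMCsrHetero entry points (when the graph lives on a CUDA device) are illustrative assumptions, not part of this diff:

import dgl
import dgl.function as fn
import torch

# Toy heterograph with two relation types; copy_u + sum per relation corresponds
# to the "copy_lhs"/"sum" SpMM path handled in the kernels below.
g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([0, 2], [0, 1]),
})
g.nodes['user'].data['h'] = torch.randn(3, 4)

# One SpMM per relation type; results landing on the same node type are summed.
g.multi_update_all(
    {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h_new')),
     'plays': (fn.copy_u('h', 'm'), fn.sum('m', 'h_new'))},
    cross_reducer='sum')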
@@ -91,6 +91,42 @@ void SDDMMCsr(const std::string& op,
  });
}
/*!
 * \brief CUDA implementation of g-SDDMM on heterograph using
 * Csr format.
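 * \param op The binary operator name (e.g. "mul", "copy_lhs").
 * \param bcast Broadcast information.
 * \param vec_csr Csr matrices, one per relation (edge) type.
 * \param vec_lhs Left-operand feature tensors, picked per relation via lhs_eid.
 * \param vec_rhs Right-operand feature tensors, picked per relation via rhs_eid.
 * \param vec_out Result features on edges, one per relation type.
 * \param lhs_target Target of the left operand (source node, edge, or destination node).
 * \param rhs_target Target of the right operand (source node, edge, or destination node).
 * \param lhs_eid Index into vec_lhs for each relation type.
 * \param rhs_eid Index into vec_rhs for each relation type.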
*/
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
// TODO(Israt): Resolve PR - https://github.com/dmlc/dgl/issues/2995
// to use maxstream > 1
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM for each relation type */
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
CSRMatrix csr = vec_csr[etype];
NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype];
cuda::SDDMMCsrHetero<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out, thr_entry->stream);
}
});
});
});
}
/*!
 * \brief CUDA implementation of g-SDDMM on Coo format.
 */
@@ -137,6 +173,50 @@ template void SDDMMCsr<kDLGPU, int64_t, 64>(
    NDArray lhs, NDArray rhs, NDArray out,
    int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDLGPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCoo<kDLGPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
    NDArray lhs, NDArray rhs, NDArray out,
...
@@ -310,6 +310,58 @@ void SDDMMCsr(
  });
}
/*!
* \brief CUDA implementation of g-SDDMM on heterograph using Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param lhs The left hand side operand feature.
 * \param rhs The right hand side operand feature.
 * \param out The result feature on edges.
 * \param strm_id The CUDA stream on which the kernel is launched.
*/
template <typename Idx, typename DType, typename Op,
int LhsTarget = 0, int RhsTarget = 2>
void SDDMMCsrHetero(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray lhs,
NDArray rhs,
NDArray out,
cudaStream_t strm_id) {
const Idx *indptr = csr.indptr.Ptr<Idx>();
const Idx *indices = csr.indices.Ptr<Idx>();
const Idx *edge_map = csr.data.Ptr<Idx>();
const DType *lhs_data = lhs.Ptr<DType>();
const DType *rhs_data = rhs.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];
int64_t *lhs_off = nullptr, *rhs_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t reduce_dim = bcast.reduce_size;
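  // 2D launch configuration: the x dimension covers the (broadcast) output feature length,
  // the y dimension covers the E edges.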
const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nbx = (len + ntx - 1) / ntx;
const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
CUDA_KERNEL_CALL((SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
nblks, nthrs, 0, strm_id,
lhs_data, rhs_data, out_data,
indptr, indices, edge_map,
N, M, E, reduce_dim,
lhs_off, rhs_off,
lhs_len, rhs_len, len);
});
}
} // namespace cuda
} // namespace aten
} // namespace dgl
...
@@ -282,6 +282,112 @@ void CusparseCsrmm2(
  if (valptr)
    device->FreeWorkspace(ctx, valptr);
}
/*! Cusparse implementation of SpMM on Csr format. This heterograph variant accumulates into the output buffer (beta = 1). */
template <typename DType, typename IdType>
void CusparseCsrmm2Hetero(
const DLContext& ctx,
const CSRMatrix& csr,
const DType* B_data, const DType* A_data,
DType* C_data,
int64_t x_length,
cudaStream_t strm_id) {
  // We use csrmm2 to perform the following operation:
  // C = A x B, where A is a sparse matrix in CSR format and B is the dense matrix of node
  // features. However, the legacy cusparse csrmm2 only supports column-major layout, while our
  // tensors are stored row-major, so on that path the actual computation is:
  // C = trans(A x trans(B)).
  // We use cublasXgeam to implement the transposition and allocate intermediate
  // workspace memory for it.
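  // Why the transpose is needed (legacy path, CUDART < 11.0): reinterpreting the row-major
  // (k x n) buffer of B as column-major yields trans(B); csrmm2 with op(B) = transpose then
  // computes A x B, whose column-major m x n result equals trans(A x B) in row-major and is
  // transposed back into C at the end. The generic SpMM API (CUDART >= 11.0) accepts
  // CUSPARSE_ORDER_ROW directly, so no extra transpose is required there.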
int int_maxlimit = std::numeric_limits<int>::max();
CHECK_GE(int_maxlimit, (csr.num_rows));
CHECK_GE(int_maxlimit, csr.num_cols);
CHECK_GE(int_maxlimit, csr.indices->shape[0]);
const int m = csr.num_rows;
const int n = x_length;
const int k = csr.num_cols;
const int nnz = csr.indices->shape[0];
const DType alpha = 1.0;
const DType beta = 1.0;
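  // beta = 1 makes cusparseSpMM/csrmm2 accumulate into the existing contents of C, so the
  // per-relation-type calls issued by SpMMCsrHetero sum into the same output buffer
  // (this assumes the caller zero-initializes the output beforehand).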
// device
auto device = runtime::DeviceAPI::Get(ctx);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
// allocate cusparse handle if needed
if (!thr_entry->cusparse_handle) {
CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
}
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, strm_id));
  // all-ones edge-value array, used when no explicit edge weights are given (A_data == nullptr)
DType* valptr = nullptr;
if (!A_data) {
valptr = static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
_Fill(valptr, nnz, static_cast<DType>(1.));
}
#if CUDART_VERSION >= 11000
cusparseSpMatDescr_t matA;
cusparseDnMatDescr_t matB, matC;
constexpr auto dtype = cuda_dtype<DType>::value;
constexpr auto idtype = cusparse_idtype<IdType>::value;
CUSPARSE_CALL(cusparseCreateCsr(&matA,
m, k, nnz,
static_cast<IdType*>(csr.indptr->data),
static_cast<IdType*>(csr.indices->data),
const_cast<DType*>(valptr? valptr : A_data),
idtype, idtype,
CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateDnMat(&matB,
k, n, n,
const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(cusparseCreateDnMat(&matC,
m, n, n,
C_data, dtype, CUSPARSE_ORDER_ROW));
auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
size_t workspace_size;
CUSPARSE_CALL(cusparseSpMM_bufferSize(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPMM_CSR_ALG2,
&workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseSpMM(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPMM_CSR_ALG2,
workspace));
device->FreeWorkspace(ctx, workspace);
CUSPARSE_CALL(cusparseDestroySpMat(matA));
CUSPARSE_CALL(cusparseDestroyDnMat(matB));
CUSPARSE_CALL(cusparseDestroyDnMat(matC));
#else
// allocate matrix for temporary transposed output
DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
cusparseMatDescr_t descr;
CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
CHECK_EQ(sizeof(IdType), sizeof(int32_t));
CUSPARSE_CALL(Xcsrmm2<DType>(
thr_entry->cusparse_handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE,
m, n, k, nnz, &alpha,
descr, (valptr)? valptr : A_data,
static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, trans_out, m));
CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
// transpose the output matrix
_Transpose(trans_out, C_data, n, m);
device->FreeWorkspace(ctx, trans_out);
#endif
if (valptr)
device->FreeWorkspace(ctx, valptr);
}
} // namespace cusparse
#define SWITCH_OP(op, Op, ...) \
@@ -400,6 +506,110 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
  }
}
/*!
 * \brief CUDA implementation of g-SpMM on heterograph using Csr format.
 * \note Uses cusparse if the reduce operator is `sum` and there is
 * no broadcast; falls back to DGL's own kernel otherwise.
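 * \param op The binary operator name (e.g. "mul", "copy_lhs").
 * \param reduce The reduce operator name: "sum", "max" or "min".
 * \param bcast Broadcast information.
 * \param vec_csr Csr matrices, one per relation (edge) type.
 * \param vec_ufeat Source-node feature tensors, indexed by node type.
 * \param vec_efeat Edge feature tensors, indexed by relation type.
 * \param vec_out Output node feature tensors, indexed by destination node type.
 * \param out_aux Auxiliary arg-min/max arrays used by the min/max reducers.
 * \param ufeat_ntids Source node type id of each relation type.
 * \param out_ntids Destination node type id of each relation type.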
*/
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat,
std::vector<NDArray> vec_out,
const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, // ufeat node type id
const std::vector<dgl_type_t>& out_ntids) { // output node type id
int64_t feat_len = bcast.out_len;
bool is_scalar_efeat = vec_efeat.size() != 0;
bool use_efeat = op != "copy_lhs";
// TODO(Israt): 1:Resolve PR-https://github.com/dmlc/dgl/issues/2995
// to use maxstream > 1
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
const dgl_type_t src_id = ufeat_ntids[etype];
const dgl_type_t dst_id = out_ntids[etype];
CSRMatrix csr = vec_csr[etype];
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
/* Call SpMM for each relation type */
if (op == "copy_lhs" && cusparse_available<bits, IdType>()) { // cusparse
int64_t x_length = 1;
NDArray nd_ufeat = vec_ufeat[ufeat_ntids[0]];
for (int i = 1; i < nd_ufeat->ndim; ++i) {
x_length *= nd_ufeat->shape[i];
}
cusparse::CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
nullptr,
static_cast<DType*>(vec_out[dst_id]->data),
x_length,
thr_entry->stream);
} else if (op == "mul" && is_scalar_efeat &&
cusparse_available<bits, IdType>()) { // cusparse
NDArray efeat = vec_efeat[etype];
int64_t x_length = 1;
NDArray nd_ufeat = vec_ufeat[ufeat_ntids[0]];
for (int i = 1; i < nd_ufeat->ndim; ++i) {
x_length *= nd_ufeat->shape[i];
}
if (!IsNullArray(csr.data)) {
SWITCH_BITS(bits, DType, {
efeat = _IndexSelect<DType, IdType>(vec_efeat[etype], csr.data);
});
}
SWITCH_BITS(bits, DType, {
cusparse::CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data),
static_cast<DType*>(vec_out[dst_id]->data),
x_length,
thr_entry->stream);
});
} else { // general kernel
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
SWITCH_OP(op, Op, {
cuda::SpMMCsrHetero<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id],
NullArray(), NullArray(), thr_entry->stream);
});
}
});
} else if (reduce == "max") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id],
out_aux[0], out_aux[1], thr_entry->stream);
});
});
} else if (reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id],
out_aux[0], out_aux[1], thr_entry->stream);
});
});
} else {
LOG(FATAL) << "Not implemented";
}
}
}
/*!
 * \brief CUDA implementation of g-SpMM on Coo format.
@@ -463,6 +673,43 @@ template void SpMMCsr<kDLGPU, int64_t, 64>(
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDLGPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCoo<kDLGPU, int32_t, 16>(
    const std::string& op, const std::string& reduce,
    const BcastOff& bcast, const COOMatrix& coo,
...
@@ -160,7 +160,7 @@ __global__ void SpMMCsrKernel(
      DType out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
      ReduceOp::Call(&local_accum, &local_argu, &local_arge, out, cid, eid);
    }
    out[ty * out_len + tx] += local_accum;
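    // '+=' lets repeated kernel launches (one per relation type on heterographs)
    // accumulate into the same output buffer.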
    if (ReduceOp::require_arg && BinaryOp::use_lhs)
      arg_u[ty * out_len + tx] = local_argu;
    if (ReduceOp::require_arg && BinaryOp::use_rhs)
@@ -305,6 +305,65 @@ void SpMMCsr(
  });
}
/*!
 * \brief CUDA implementation of g-SpMM on heterograph using Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result feature on destination nodes.
 * \param argu Arg-Min/Max on source nodes, i.e. the source node indices that
 *        correspond to the minimum/maximum values of the reduction result on
 *        destination nodes. Useful for computing gradients of the Min/Max reducer.
 * \param arge Arg-Min/Max on edges, i.e. the edge indices that correspond to the
 *        minimum/maximum values of the reduction result on destination nodes.
 *        Useful for computing gradients of the Min/Max reducer.
 * \param strm_id The CUDA stream on which the kernel is launched.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp>
void SpMMCsrHetero(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge,
cudaStream_t strm_id) {
const Idx *indptr = csr.indptr.Ptr<Idx>();
const Idx *indices = csr.indices.Ptr<Idx>();
const Idx *edge_map = csr.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>();
const DType *efeat_data = efeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
Idx* argu_data = argu.Ptr<Idx>();
Idx* arge_data = arge.Ptr<Idx>();
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
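  // 2D launch configuration: the x dimension covers the output feature length,
  // the y dimension covers the destination rows of the Csr matrix.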
const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nbx = (len + ntx - 1) / ntx;
const int nby = FindNumBlocks<'y'>((csr.num_rows + nty - 1) / nty);
//LOG(INFO) << "nblks=(" << nbx << ", " << nby << ") nthrs=(" << ntx << ", " << nty << ")";
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
CUDA_KERNEL_CALL((SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, strm_id,
ufeat_data, efeat_data, out_data, argu_data, arge_data,
indptr, indices, edge_map,
csr.num_rows, csr.num_cols,
ubcast_off, ebcast_off,
      lhs_len, rhs_len, len);
});
}
} // namespace cuda
} // namespace aten
} // namespace dgl
...