/*! * Copyright (c) 2021 by Contributors * \file array/cuda/csr_get_data.cu * \brief Retrieve entries of a CSR matrix */ #include #include #include #include #include "../../runtime/cuda/cuda_common.h" #include "./utils.h" namespace dgl { using runtime::NDArray; namespace aten { namespace impl { template NDArray CSRGetData( CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) { const int64_t rowlen = rows->shape[0]; const int64_t collen = cols->shape[0]; CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1)) << "Invalid row and col id array."; const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1; const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1; const int64_t rstlen = std::max(rowlen, collen); IdArray rst = NDArray::Empty({rstlen}, weights->dtype, rows->ctx); if (rstlen == 0) return rst; hipStream_t stream = runtime::getCurrentCUDAStream(); const int nt = cuda::FindNumThreads(rstlen); const int nb = (rstlen + nt - 1) / nt; if (return_eids) BUG_IF_FAIL(DLDataTypeTraits::dtype == rows->dtype) << "DType does not match row's dtype."; // TODO(minjie): use binary search for sorted csr CUDA_KERNEL_CALL(cuda::_LinearSearchKernel, nb, nt, 0, stream, csr.indptr.Ptr(), csr.indices.Ptr(), CSRHasData(csr)? csr.data.Ptr() : nullptr, rows.Ptr(), cols.Ptr(), row_stride, col_stride, rstlen, return_eids ? nullptr : weights.Ptr(), filler, rst.Ptr()); return rst; } #ifdef USE_FP16 template NDArray CSRGetData( CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler); template NDArray CSRGetData( CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler); #endif template NDArray CSRGetData( CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); template NDArray CSRGetData( CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); template NDArray CSRGetData( CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); template NDArray CSRGetData( CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); // For CSRGetData(CSRMatrix, NDArray, NDArray) template NDArray CSRGetData( CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler); template NDArray CSRGetData( CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler); } // namespace impl } // namespace aten } // namespace dgl