#include "hip/hip_runtime.h" /*! * Copyright (c) 2020 by Contributors * \file array/cuda/spmat_op_impl_csr.cu * \brief CSR operator CPU implementation */ #include #include #include #include #include "../../runtime/cuda/cuda_common.h" #include "./utils.h" #include "./atomic.cuh" #include "./dgl_cub.cuh" namespace dgl { using runtime::NDArray; namespace aten { namespace impl { ///////////////////////////// CSRIsNonZero ///////////////////////////// template bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) { hipStream_t stream = runtime::getCurrentCUDAStream(); const auto& ctx = csr.indptr->ctx; IdArray rows = aten::VecToIdArray({row}, sizeof(IdType) * 8, ctx); IdArray cols = aten::VecToIdArray({col}, sizeof(IdType) * 8, ctx); rows = rows.CopyTo(ctx); cols = cols.CopyTo(ctx); IdArray out = aten::NewIdArray(1, ctx, sizeof(IdType) * 8); const IdType* data = nullptr; // TODO(minjie): use binary search for sorted csr CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel, 1, 1, 0, stream, csr.indptr.Ptr(), csr.indices.Ptr(), data, rows.Ptr(), cols.Ptr(), 1, 1, 1, static_cast(nullptr), static_cast(-1), out.Ptr()); out = out.CopyTo(DLContext{kDLCPU, 0}); return *out.Ptr() != -1; } template bool CSRIsNonZero(CSRMatrix, int64_t, int64_t); template bool CSRIsNonZero(CSRMatrix, int64_t, int64_t); template NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) { const auto rowlen = row->shape[0]; const auto collen = col->shape[0]; const auto rstlen = std::max(rowlen, collen); NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx); if (rstlen == 0) return rst; const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1; const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1; hipStream_t stream = runtime::getCurrentCUDAStream(); const int nt = dgl::cuda::FindNumThreads(rstlen); const int nb = (rstlen + nt - 1) / nt; const IdType* data = nullptr; // TODO(minjie): use binary search for sorted csr CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel, nb, nt, 0, stream, csr.indptr.Ptr(), csr.indices.Ptr(), data, row.Ptr(), col.Ptr(), row_stride, col_stride, rstlen, static_cast(nullptr), static_cast(-1), rst.Ptr()); return rst != -1; } template NDArray CSRIsNonZero(CSRMatrix, NDArray, NDArray); template NDArray CSRIsNonZero(CSRMatrix, NDArray, NDArray); ///////////////////////////// CSRHasDuplicate ///////////////////////////// /*! * \brief Check whether each row does not have any duplicate entries. * Assume the CSR is sorted. */ template __global__ void _SegmentHasNoDuplicate( const IdType* indptr, const IdType* indices, int64_t num_rows, int8_t* flags) { int tx = blockIdx.x * blockDim.x + threadIdx.x; const int stride_x = gridDim.x * blockDim.x; while (tx < num_rows) { bool f = true; for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) { f = (indices[i - 1] != indices[i]); } flags[tx] = static_cast(f); tx += stride_x; } } template bool CSRHasDuplicate(CSRMatrix csr) { if (!csr.sorted) csr = CSRSort(csr); const auto& ctx = csr.indptr->ctx; hipStream_t stream = runtime::getCurrentCUDAStream(); auto device = runtime::DeviceAPI::Get(ctx); // We allocate a workspace of num_rows bytes. It wastes a little bit memory but should // be fine. 
template <DLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
  if (!csr.sorted)
    csr = CSRSort(csr);
  const auto& ctx = csr.indptr->ctx;
  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(ctx);
  // We allocate a workspace of num_rows bytes. It wastes a little bit of
  // memory but should be fine.
  int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
  const int nt = dgl::cuda::FindNumThreads(csr.num_rows);
  const int nb = (csr.num_rows + nt - 1) / nt;
  CUDA_KERNEL_CALL(_SegmentHasNoDuplicate,
      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      csr.num_rows, flags);
  bool ret = dgl::cuda::AllTrue(flags, csr.num_rows, ctx);
  device->FreeWorkspace(ctx, flags);
  return !ret;
}

template bool CSRHasDuplicate<kDLGPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLGPU, int64_t>(CSRMatrix csr);

///////////////////////////// CSRGetRowNNZ /////////////////////////////

template <DLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  const IdType cur = aten::IndexSelect<IdType>(csr.indptr, row);
  const IdType next = aten::IndexSelect<IdType>(csr.indptr, row + 1);
  return next - cur;
}

template int64_t CSRGetRowNNZ<kDLGPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLGPU, int64_t>(CSRMatrix, int64_t);

template <typename IdType>
__global__ void _CSRGetRowNNZKernel(
    const IdType* vid,
    const IdType* indptr,
    IdType* out,
    int64_t length) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    const IdType vv = vid[tx];
    out[tx] = indptr[vv + 1] - indptr[vv];
    tx += stride_x;
  }
}

template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const auto len = rows->shape[0];
  const IdType* vid_data = static_cast<IdType*>(rows->data);
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
  const int nt = dgl::cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(_CSRGetRowNNZKernel,
      nb, nt, 0, stream,
      vid_data, indptr_data, rst_data, len);
  return rst;
}

template NDArray CSRGetRowNNZ<kDLGPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLGPU, int64_t>(CSRMatrix, NDArray);

///////////////////////////// CSRGetRowColumnIndices /////////////////////////////

template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  const int64_t offset = aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);
  return csr.indices.CreateView({len}, csr.indices->dtype, offset);
}

template NDArray CSRGetRowColumnIndices<kDLGPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLGPU, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRGetRowData /////////////////////////////

template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  const IdType pos = aten::IndexSelect<IdType>(csr.indptr, row);
  if (aten::CSRHasData(csr))
    return csr.data.CreateView({len}, csr.data->dtype, pos * sizeof(IdType));
  else
    return aten::Range(pos, pos + len,
                       csr.indptr->dtype.bits, csr.indptr->ctx);
}

template NDArray CSRGetRowData<kDLGPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLGPU, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRSliceRows /////////////////////////////

template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  const int64_t num_rows = end - start;
  const IdType st_pos = aten::IndexSelect<IdType>(csr.indptr, start);
  const IdType ed_pos = aten::IndexSelect<IdType>(csr.indptr, end);
  const IdType nnz = ed_pos - st_pos;
  IdArray ret_indptr = aten::IndexSelect(csr.indptr, start, end + 1) - st_pos;
  // indices and data can be view arrays
  IdArray ret_indices = csr.indices.CreateView(
      {nnz}, csr.indices->dtype, st_pos * sizeof(IdType));
  IdArray ret_data;
  if (CSRHasData(csr))
    ret_data = csr.data.CreateView({nnz}, csr.data->dtype, st_pos * sizeof(IdType));
  else
    ret_data = aten::Range(st_pos, ed_pos,
                           csr.indptr->dtype.bits, csr.indptr->ctx);
  return CSRMatrix(num_rows, csr.num_cols,
                   ret_indptr, ret_indices, ret_data, csr.sorted);
}

template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix, int64_t, int64_t);
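// Example (for illustration): slicing rows [1, 3) of a CSR with
//   indptr = [0, 2, 3, 6]
// gives st_pos = 2, ed_pos = 6, and hence ret_indptr = [0, 1, 4]. ret_indices
// is a 4-element view starting at entry 2 of csr.indices, and, when the matrix
// has no explicit data array, ret_data = Range(2, 6) = [2, 3, 4, 5], i.e. the
// positions of the retained entries.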
/*!
 * \brief Copy data segments to the output buffers.
 *
 * For the i^th row r = row[i], copy the data from indptr[r] ~ indptr[r+1]
 * to out_data from out_indptr[i] ~ out_indptr[i+1].
 *
 * If the provided `data` array is nullptr, write the read index to out_data
 * instead.
 */
template <typename IdType, typename DType>
__global__ void _SegmentCopyKernel(
    const IdType* indptr, const DType* data,
    const IdType* row, int64_t length, int64_t n_row,
    const IdType* out_indptr, DType* out_data) {
  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    IdType rpos = dgl::cuda::_UpperBound(out_indptr, n_row, tx) - 1;
    IdType rofs = tx - out_indptr[rpos];
    const IdType u = row[rpos];
    out_data[tx] = data ? data[indptr[u] + rofs] : indptr[u] + rofs;
    tx += stride_x;
  }
}

template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const int64_t len = rows->shape[0];
  IdArray ret_indptr = aten::CumSum(aten::CSRGetRowNNZ(csr, rows), true);
  const int64_t nnz = aten::IndexSelect<IdType>(ret_indptr, len);

  const int nt = 256;  // for better GPU usage of small invocations
  const int nb = (nnz + nt - 1) / nt;

  // Copy indices.
  IdArray ret_indices = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
  CUDA_KERNEL_CALL(_SegmentCopyKernel,
      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      rows.Ptr<IdType>(), nnz, len,
      ret_indptr.Ptr<IdType>(), ret_indices.Ptr<IdType>());
  // Copy data.
  IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
  CUDA_KERNEL_CALL(_SegmentCopyKernel,
      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(),
      CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr,
      rows.Ptr<IdType>(), nnz, len,
      ret_indptr.Ptr<IdType>(), ret_data.Ptr<IdType>());
  return CSRMatrix(len, csr.num_cols,
                   ret_indptr, ret_indices, ret_data, csr.sorted);
}

template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix, NDArray);
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix, NDArray);

///////////////////////////// CSRGetDataAndIndices /////////////////////////////

/*!
 * \brief Generate a 0-1 mask for each entry that matches one of the provided
 *        (row, col) pairs.
 *
 * Examples:
 * Given a CSR matrix (with duplicate entries) as follows:
 * [[0, 1, 2, 0, 0],
 *  [1, 0, 0, 0, 0],
 *  [0, 0, 1, 1, 0],
 *  [0, 0, 0, 0, 0]]
 * Given rows: [0, 1], cols: [2, 0]
 * The result mask is: [0, 1, 1, 1, 0, 0]
 */
template <typename IdType>
__global__ void _SegmentMaskKernel(
    const IdType* indptr, const IdType* indices,
    const IdType* row, const IdType* col,
    int64_t row_stride, int64_t col_stride,
    int64_t length, IdType* mask) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    int rpos = tx * row_stride, cpos = tx * col_stride;
    const IdType r = row[rpos], c = col[cpos];
    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
      if (indices[i] == c) {
        mask[i] = 1;
      }
    }
    tx += stride_x;
  }
}

/*!
 * \brief Search for the insertion position of each needle in the hay.
 *
 * The hay is a list of sorted elements and the result, for each needle, is the
 * position of the last hay element that is less than or equal to the needle
 * (i.e. its insertion position minus one).
 *
 * It essentially performs a binary search per needle element. Requires the
 * largest element in the hay to be larger than the given needle elements.
 * Commonly used for searching the row IDs of a given set of coordinates.
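 *
 * For example, with hay = [0, 2, 3, 7] and needles = [1, 3, 5] the result is
 * pos = [0, 2, 2]; when the hay is a CSR indptr array and the needles are
 * nonzero positions, pos[i] is the row that owns needle i.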
 */
template <typename IdType>
__global__ void _SortedSearchKernel(
    const IdType* hay, int64_t hay_size,
    const IdType* needles, int64_t num_needles,
    IdType* pos) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < num_needles) {
    const IdType ele = needles[tx];
    // binary search
    IdType lo = 0, hi = hay_size - 1;
    while (lo < hi) {
      IdType mid = (lo + hi) >> 1;
      if (hay[mid] <= ele) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }
    pos[tx] = (hay[hi] == ele) ? hi : hi - 1;
    tx += stride_x;
  }
}

template <DLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray col) {
  const auto rowlen = row->shape[0];
  const auto collen = col->shape[0];
  const auto len = std::max(rowlen, collen);
  if (len == 0)
    return {NullArray(), NullArray(), NullArray()};

  const auto& ctx = row->ctx;
  const auto nbits = row->dtype.bits;
  const int64_t nnz = csr.indices->shape[0];

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  hipStream_t stream = runtime::getCurrentCUDAStream();

  // Generate a 0-1 mask for matched (row, col) positions.
  IdArray mask = Full(0, nnz, nbits, ctx);
  const int nt = dgl::cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(_SegmentMaskKernel,
      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      row.Ptr<IdType>(), col.Ptr<IdType>(),
      row_stride, col_stride, len,
      mask.Ptr<IdType>());

  IdArray idx = AsNumBits(NonZero(mask), nbits);
  if (idx->shape[0] == 0)
    // No data. Return three empty arrays.
    return {idx, idx, idx};

  // Search for row index
  IdArray ret_row = NewIdArray(idx->shape[0], ctx, nbits);
  const int nt2 = dgl::cuda::FindNumThreads(idx->shape[0]);
  const int nb2 = (idx->shape[0] + nt2 - 1) / nt2;
  CUDA_KERNEL_CALL(_SortedSearchKernel,
      nb2, nt2, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.num_rows,
      idx.Ptr<IdType>(), idx->shape[0],
      ret_row.Ptr<IdType>());

  // Column & data can be obtained by index select.
  IdArray ret_col = IndexSelect(csr.indices, idx);
  IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;
  return {ret_row, ret_col, ret_data};
}

template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int32_t>(
    CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int64_t>(
    CSRMatrix csr, NDArray rows, NDArray cols);

///////////////////////////// CSRSliceMatrix /////////////////////////////

/*!
 * \brief Generate a 0-1 mask for each index whose column is in the provided
 *        set. It also counts the number of masked values per row.
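 *
 * For example, with indptr = [0, 2, 3], indices = [1, 3, 0] and the sorted
 * column set [0, 3], the kernel produces mask = [0, 1, 1] and count = [1, 1].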
 */
template <typename IdType>
__global__ void _SegmentMaskColKernel(
    const IdType* indptr, const IdType* indices,
    int64_t num_rows, int64_t num_nnz,
    const IdType* col, int64_t col_len,
    IdType* mask, IdType* count) {
  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < num_nnz) {
    IdType rpos = dgl::cuda::_UpperBound(indptr, num_rows, tx) - 1;
    IdType cur_c = indices[tx];
    IdType i = dgl::cuda::_BinarySearch(col, col_len, cur_c);
    if (i < col_len) {
      mask[tx] = 1;
      cuda::AtomicAdd(count + rpos, IdType(1));
    }
    tx += stride_x;
  }
}

template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr,
                         runtime::NDArray rows, runtime::NDArray cols) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = rows->ctx;
  const auto& dtype = rows->dtype;
  const auto nbits = dtype.bits;
  const int64_t new_nrows = rows->shape[0];
  const int64_t new_ncols = cols->shape[0];

  if (new_nrows == 0 || new_ncols == 0)
    return CSRMatrix(new_nrows, new_ncols,
                     Full(0, new_nrows + 1, nbits, ctx),
                     NullArray(dtype, ctx), NullArray(dtype, ctx));

  // First slice rows
  csr = CSRSliceRows<XPU, IdType>(csr, rows);

  if (csr.indices->shape[0] == 0)
    return CSRMatrix(new_nrows, new_ncols,
                     Full(0, new_nrows + 1, nbits, ctx),
                     NullArray(dtype, ctx), NullArray(dtype, ctx));

  // Generate a 0-1 mask for matched (row, col) positions.
  IdArray mask = Full(0, csr.indices->shape[0], nbits, ctx);
  // A count for how many masked values per row.
  IdArray count = NewIdArray(csr.num_rows, ctx, nbits);
  CUDA_CALL(hipMemset(count.Ptr<IdType>(), 0, sizeof(IdType) * csr.num_rows));

  const int64_t nnz_csr = csr.indices->shape[0];
  const int nt = 256;
  // In general the ``cols'' array is sorted, but this is not guaranteed,
  // so sort it first. The sort is not in place.
  auto device = runtime::DeviceAPI::Get(ctx);
  auto cols_size = cols->shape[0];
  IdArray sorted_array = NewIdArray(cols->shape[0], ctx, cols->dtype.bits);
  auto ptr_sorted_cols = sorted_array.Ptr<IdType>();
  auto ptr_cols = cols.Ptr<IdType>();
  size_t workspace_size = 0;
  CUDA_CALL(hipcub::DeviceRadixSort::SortKeys(
      nullptr, workspace_size, ptr_cols, ptr_sorted_cols,
      cols->shape[0], 0, sizeof(IdType) * 8, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUDA_CALL(hipcub::DeviceRadixSort::SortKeys(
      workspace, workspace_size, ptr_cols, ptr_sorted_cols,
      cols->shape[0], 0, sizeof(IdType) * 8, stream));
  device->FreeWorkspace(ctx, workspace);

  // Execute _SegmentMaskColKernel
  int nb = (nnz_csr + nt - 1) / nt;
  CUDA_KERNEL_CALL(_SegmentMaskColKernel,
      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      csr.num_rows, nnz_csr,
      ptr_sorted_cols, cols_size,
      mask.Ptr<IdType>(), count.Ptr<IdType>());

  IdArray idx = AsNumBits(NonZero(mask), nbits);
  if (idx->shape[0] == 0)
    return CSRMatrix(new_nrows, new_ncols,
                     Full(0, new_nrows + 1, nbits, ctx),
                     NullArray(dtype, ctx), NullArray(dtype, ctx));

  // Indptr needs to be adjusted according to the new nnz per row.
  IdArray ret_indptr = CumSum(count, true);

  // Column & data can be obtained by index select.
  IdArray ret_col = IndexSelect(csr.indices, idx);
  IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;
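
  // The relabeling below maps each retained column ID to its position in
  // ``cols''. Example (for illustration): with cols = [4, 2], col_hash maps
  // original column 4 -> 0 and column 2 -> 1; slots of col_hash that do not
  // correspond to a kept column are never read.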
  // Relabel column
  IdArray col_hash = NewIdArray(csr.num_cols, ctx, nbits);
  Scatter_(cols, Range(0, cols->shape[0], nbits, ctx), col_hash);
  ret_col = IndexSelect(col_hash, ret_col);

  return CSRMatrix(new_nrows, new_ncols, ret_indptr, ret_col, ret_data);
}

template CSRMatrix CSRSliceMatrix<kDLGPU, int32_t>(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDLGPU, int64_t>(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

}  // namespace impl
}  // namespace aten
}  // namespace dgl