Unverified commit c59000ac, authored by Xin Yao, committed by GitHub
Browse files

[Cleanup] Remove duplicated _IndexSelect (#4874)

parent 0cb5f0fd
......@@ -45,7 +45,7 @@ void SpMMCsr(
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
if (!IsNullArray(csr.data)) {
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
efeat = IndexSelect(efeat, csr.data);
}
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
......
......@@ -98,19 +98,6 @@ cublasStatus_t Xgeam<double>(
handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}
/**
 * @brief IndexSelect operator kernel implementation (row gather).
 *
 * For the block with index i = blockIdx.x, copies the m consecutive elements
 * starting at in[idx[i] * m] to the m consecutive elements starting at
 * out[i * m]. The threads of the block walk the m columns starting at
 * threadIdx.x with a step of blockDim.x.
 *
 * @note duplicate of IndexSelectKernel defined in array_index_select.cu
 *
 * @param in  Source buffer, read in rows of m elements.
 * @param idx Row indices; idx[i] selects the source row for output row i.
 * @param out Destination buffer, written in rows of m elements.
 * @param n   Not read by the kernel.
 *            NOTE(review): presumably the number of output rows (and the
 *            number of blocks launched) -- confirm at the call site.
 * @param m   Number of elements per row.
 */
template <typename DType, typename IdType>
__global__ void _IndexSelectKernel(
    const DType* __restrict__ in, const IdType* __restrict__ idx,
    DType* __restrict__ out, int n, int m) {
  // Widen to 64 bits *before* multiplying by m: the row offsets can exceed
  // the 32-bit range even when the row index and m individually fit in int.
  const int64_t out_base = static_cast<int64_t>(blockIdx.x) * m;
  const int64_t in_base = static_cast<int64_t>(idx[blockIdx.x]) * m;
  for (int j = threadIdx.x; j < m; j += blockDim.x)
    out[out_base + j] = in[in_base + j];
}
/**
* @brief Transpose operator kernel implementation.
* @note not efficient but it's not a bottleneck, used for float16 dtype.
......@@ -168,42 +155,6 @@ void _Transpose<__nv_bfloat16>(
}
#endif // BF16_ENABLED
/**
 * @brief IndexSelect operator kernel implementation (element gather).
 *
 * Writes out[i] = array[index[i]] for every i in [0, length). Each thread
 * starts at its global thread id and advances by the total number of
 * launched threads (gridDim.x * blockDim.x), so any 1-D launch configuration
 * covers the whole range.
 *
 * @param array  Source buffer indexed by the values stored in index.
 * @param index  Positions to read from array; length entries are consumed.
 * @param length Number of elements to gather.
 * @param out    Destination buffer; length entries are written.
 *
 * @note Values in index are not bounds-checked against the size of array.
 */
template <typename DType, typename IdType>
__global__ void _IndexSelectKernel(
    const DType* array, const IdType* index, int64_t length, DType* out) {
  // Keep the loop index and stride in 64 bits: length is int64_t, and a
  // 32-bit counter would overflow before reaching lengths near or above 2^31.
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < length; tx += stride_x) {
    out[tx] = array[index[tx]];
  }
}
/**
 * @brief IndexSelect operator.
 *
 * Allocates a new 1-D NDArray with index->shape[0] elements, using the dtype
 * and context of array, and fills it on the GPU so that
 * ret[i] = array[index[i]].
 *
 * @note duplicate of IndexSelect defined in array_op.h but it can
 *       not be applied to float16 dtype.
 * @note array->shape is not consulted and the values in index are not
 *       bounds-checked here.
 *
 * @tparam DType  Element type the data pointer of array is cast to.
 * @tparam IdType Element type the data pointer of index is cast to.
 * @param array Source array.
 * @param index Positions to gather from array.
 * @return The gathered array; an empty array when index has no elements.
 */
template <typename DType, typename IdType>
NDArray _IndexSelect(NDArray array, NDArray index) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType* array_data = static_cast<DType*>(array->data);
  const IdType* idx_data = static_cast<IdType*>(index->data);
  const int64_t len = index->shape[0];
  NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
  // Nothing to gather: return the empty result without launching a kernel.
  if (len == 0) return ret;
  DType* ret_data = static_cast<DType*>(ret->data);
  // Threads per block come from FindNumThreads; the block count is the
  // ceiling of len / nt so every output element is covered by one thread.
  const int nt = FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _IndexSelectKernel, nb, nt, 0, stream, array_data, idx_data, len,
      ret_data);
  return ret;
}
#if CUDART_VERSION < 11000
template <typename DType>
cusparseStatus_t Xcsrmm2(
......
......@@ -134,7 +134,7 @@ void SpMMCsrHetero(
cusparse_available<DType, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
efeat = IndexSelect(efeat, csr.data);
CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment