Unverified Commit c59000ac authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Cleanup] Remove duplicated _IndexSelect (#4874)

parent 0cb5f0fd
...@@ -45,7 +45,7 @@ void SpMMCsr( ...@@ -45,7 +45,7 @@ void SpMMCsr(
int64_t x_length = 1; int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i]; for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
if (!IsNullArray(csr.data)) { if (!IsNullArray(csr.data)) {
efeat = _IndexSelect<DType, IdType>(efeat, csr.data); efeat = IndexSelect(efeat, csr.data);
} }
CusparseCsrmm2<DType, IdType>( CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr, static_cast<DType*>(ufeat->data), ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
......
...@@ -98,19 +98,6 @@ cublasStatus_t Xgeam<double>( ...@@ -98,19 +98,6 @@ cublasStatus_t Xgeam<double>(
handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
} }
/**
* @brief IndexSelect operator kernel implementation.
* @note duplicate of IndexSelectKernel defined in array_index_select.cu
*/
template <typename DType, typename IdType>
__global__ void _IndexSelectKernel(
const DType* __restrict__ in, const IdType* __restrict__ idx,
DType* __restrict__ out, int n, int m) {
int i = blockIdx.x;
for (int j = threadIdx.x; j < m; j += blockDim.x)
out[i * m + j] = in[idx[i] * m + j];
}
/** /**
* @brief Transpose operator kernel implementation. * @brief Transpose operator kernel implementation.
* @note not efficient but it's not a bottleneck, used for float16 dtype. * @note not efficient but it's not a bottleneck, used for float16 dtype.
...@@ -168,42 +155,6 @@ void _Transpose<__nv_bfloat16>( ...@@ -168,42 +155,6 @@ void _Transpose<__nv_bfloat16>(
} }
#endif // BF16_ENABLED #endif // BF16_ENABLED
/**
* @brief
*/
template <typename DType, typename IdType>
__global__ void _IndexSelectKernel(
const DType* array, const IdType* index, int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
out[tx] = array[index[tx]];
tx += stride_x;
}
}
/* @brief IndexSelect operator.
* @note duplicate of IndexSelect defined in array_op.h but it can
* not be applied to float16 dtype.
*/
template <typename DType, typename IdType>
NDArray _IndexSelect(NDArray array, NDArray index) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* array_data = static_cast<DType*>(array->data);
const IdType* idx_data = static_cast<IdType*>(index->data);
const int64_t arr_len = array->shape[0];
const int64_t len = index->shape[0];
NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
if (len == 0) return ret;
DType* ret_data = static_cast<DType*>(ret->data);
const int nt = FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(
_IndexSelectKernel, nb, nt, 0, stream, array_data, idx_data, len,
ret_data);
return ret;
}
#if CUDART_VERSION < 11000 #if CUDART_VERSION < 11000
template <typename DType> template <typename DType>
cusparseStatus_t Xcsrmm2( cusparseStatus_t Xcsrmm2(
......
...@@ -134,7 +134,7 @@ void SpMMCsrHetero( ...@@ -134,7 +134,7 @@ void SpMMCsrHetero(
cusparse_available<DType, IdType>(more_nnz)) { // cusparse cusparse_available<DType, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype]; NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data)) if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(efeat, csr.data); efeat = IndexSelect(efeat, csr.data);
CusparseCsrmm2Hetero<DType, IdType>( CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data), csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data), static_cast<DType*>(efeat->data),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment