Unverified Commit 2523bc7a authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Performance] Improve the performance of SpMMCsr by reconfiguration (#4363)

* Change CUDA_MAX_NUM_THREADS to 256

* change the configuration of grid
parent 18d89b5d
...@@ -548,11 +548,11 @@ __global__ void SpMMCsrKernel( ...@@ -548,11 +548,11 @@ __global__ void SpMMCsrKernel(
const int64_t* __restrict__ ebcast_off, const int64_t* __restrict__ ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) { int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
// SPMM with CSR. // SPMM with CSR.
int ty = blockIdx.y * blockDim.y + threadIdx.y; int ty = blockIdx.x * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y; const Idx stride_y = blockDim.y * gridDim.x;
const int stride_x = blockDim.x * gridDim.x; const int stride_x = blockDim.x * gridDim.y;
while (ty < num_rows) { while (ty < num_rows) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.y * blockDim.x + threadIdx.x;
while (tx < out_len) { while (tx < out_len) {
DType local_accum = ReduceOp::zero(); DType local_accum = ReduceOp::zero();
Idx local_argu = 0, local_arge = 0; Idx local_argu = 0, local_arge = 0;
...@@ -759,8 +759,8 @@ void SpMMCsr( ...@@ -759,8 +759,8 @@ void SpMMCsr(
rhs_len = bcast.rhs_len; rhs_len = bcast.rhs_len;
const int ntx = FindNumThreads(len); const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx; const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nbx = (len + ntx - 1) / ntx; const int nby= (len + ntx - 1) / ntx;
const int nby = FindNumBlocks<'y'>((csr.num_rows + nty - 1) / nty); const int nbx = FindNumBlocks<'x'>((csr.num_rows + nty - 1) / nty);
//LOG(INFO) << "nblks=(" << nbx << ", " << nby << ") nthrs=(" << ntx << ", " << nty << ")"; //LOG(INFO) << "nblks=(" << nbx << ", " << nby << ") nthrs=(" << ntx << ", " << nty << ")";
const dim3 nblks(nbx, nby); const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty); const dim3 nthrs(ntx, nty);
......
...@@ -18,7 +18,8 @@ namespace cuda { ...@@ -18,7 +18,8 @@ namespace cuda {
#define CUDA_MAX_NUM_BLOCKS_X 0x7FFFFFFF #define CUDA_MAX_NUM_BLOCKS_X 0x7FFFFFFF
#define CUDA_MAX_NUM_BLOCKS_Y 0xFFFF #define CUDA_MAX_NUM_BLOCKS_Y 0xFFFF
#define CUDA_MAX_NUM_BLOCKS_Z 0xFFFF #define CUDA_MAX_NUM_BLOCKS_Z 0xFFFF
#define CUDA_MAX_NUM_THREADS 1024 // The max number of threads per block
#define CUDA_MAX_NUM_THREADS 256
#ifdef USE_FP16 #ifdef USE_FP16
#define SWITCH_BITS(bits, DType, ...) \ #define SWITCH_BITS(bits, DType, ...) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment