Unverified Commit 082f999a authored by Burc Eryilmaz's avatar Burc Eryilmaz Committed by GitHub
Browse files

Fix cublasLt context create/destroy overhead in MLP extension (#1083)

* don't create cublasLt handle, fix zero block size case

* cleanup
parent b8be1bc7
......@@ -718,7 +718,7 @@ void get_biasAddRelu_bprop_grid_size(
// Get number of SMs for efficient reduction.
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// can switch to occupancy calculation. use 4 below now for sm_70
int max_blocks_y = num_SMs * 4 / (*grid_x);
int max_blocks_y = (num_SMs * 4+(*grid_x)-1) / (*grid_x);
// block_y should be from minimal work per thread
int nRedSplits = (batch_size + block_y - 1) / block_y;
  // increase number of elements per-thread reduction to not launch more than enough
......@@ -1252,9 +1252,6 @@ int mlp_fp(
// Get cublas handle from Pytorch
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasLtHandle_t ltHandle;
cublasStatus_t lthandle_status;
lthandle_status = cublasLtCreate(&ltHandle);
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
......@@ -1274,9 +1271,10 @@ int mlp_fp(
// try with cublaslt first for supported case with valid handle
int cublaslt_status = 1;
if(lthandle_status == CUBLAS_STATUS_SUCCESS && activation < 2){
if(activation < 1){
cublaslt_status = mlp_gemm_lt(
ltHandle,
//ltHandle,
(cublasLtHandle_t)handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
ofeat,
......@@ -1357,9 +1355,6 @@ int mlp_fp(
reserved_space_y += ofeat * batch_size;
}
if(lthandle_status == CUBLAS_STATUS_SUCCESS) cublasLtDestroy(ltHandle);
return 0;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment