Unverified Commit 082f999a authored by Burc Eryilmaz's avatar Burc Eryilmaz Committed by GitHub
Browse files

Fix cublasLt context create/destroy overhead in MLP extension (#1083)

* don't create cublasLt handle, fix zero block size case

* cleanup
parent b8be1bc7
...@@ -718,7 +718,7 @@ void get_biasAddRelu_bprop_grid_size( ...@@ -718,7 +718,7 @@ void get_biasAddRelu_bprop_grid_size(
// Get number of SMs for efficient reduction. // Get number of SMs for efficient reduction.
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// can switch to occupancy calculation. use 4 below now for sm_70 // can switch to occupancy calculation. use 4 below now for sm_70
int max_blocks_y = num_SMs * 4 / (*grid_x); int max_blocks_y = (num_SMs * 4+(*grid_x)-1) / (*grid_x);
// block_y should be from minimal work per thread // block_y should be from minimal work per thread
int nRedSplits = (batch_size + block_y - 1) / block_y; int nRedSplits = (batch_size + block_y - 1) / block_y;
// increase number of elem per thread reduction to not launch more than enough // increase number of elem per thread reduction to not launch more than enough
...@@ -1252,9 +1252,6 @@ int mlp_fp( ...@@ -1252,9 +1252,6 @@ int mlp_fp(
// Get cublas handle from Pytorch // Get cublas handle from Pytorch
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasLtHandle_t ltHandle;
cublasStatus_t lthandle_status;
lthandle_status = cublasLtCreate(&ltHandle);
// Get the stream from cublas handle to reuse for biasReLU kernel. // Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream; cudaStream_t stream;
cublasGetStream(handle, &stream); cublasGetStream(handle, &stream);
...@@ -1274,28 +1271,29 @@ int mlp_fp( ...@@ -1274,28 +1271,29 @@ int mlp_fp(
// try with cublaslt first for supported case with valid handle // try with cublaslt first for supported case with valid handle
int cublaslt_status = 1; int cublaslt_status = 1;
if(lthandle_status == CUBLAS_STATUS_SUCCESS && activation < 2){ if(activation < 1){
cublaslt_status = mlp_gemm_lt( cublaslt_status = mlp_gemm_lt(
ltHandle, //ltHandle,
CUBLAS_OP_T, (cublasLtHandle_t)handle,
CUBLAS_OP_N, CUBLAS_OP_T,
ofeat, CUBLAS_OP_N,
batch_size, ofeat,
ifeat, batch_size,
&one, ifeat,
weight, &one,
ifeat, weight,
input, ifeat,
ifeat, input,
&zero, ifeat,
output, &zero,
ofeat, output,
lt_workspace, ofeat,
1 << 22, lt_workspace,
stream, 1 << 22,
use_bias == 1, stream,
activation == 1, use_bias == 1,
bias); activation == 1,
bias);
} }
// if cublaslt failed or not executed, fallback to cublas // if cublaslt failed or not executed, fallback to cublas
...@@ -1357,9 +1355,6 @@ int mlp_fp( ...@@ -1357,9 +1355,6 @@ int mlp_fp(
reserved_space_y += ofeat * batch_size; reserved_space_y += ofeat * batch_size;
} }
if(lthandle_status == CUBLAS_STATUS_SUCCESS) cublasLtDestroy(ltHandle);
return 0; return 0;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment