Unverified Commit 00c1e56d authored by Burc Eryilmaz's avatar Burc Eryilmaz Committed by GitHub
Browse files

compile cublasLt code only for cublas >= 11.0 (#1108)


Co-authored-by: default avatarSukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
parent 082f999a
...@@ -10,9 +10,10 @@ ...@@ -10,9 +10,10 @@
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt // includes cublaslt
#include <cublasLt.h> #include <cublasLt.h>
#endif
// constants for fused bias+relu kernel // constants for fused bias+relu kernel
#define BIAS_RELU_FW_NTHREADS 128 // forward number of thread per block #define BIAS_RELU_FW_NTHREADS 128 // forward number of thread per block
#define BIAS_RELU_BW_NTHREADS_X 32 // backward number of thread in feature dim #define BIAS_RELU_BW_NTHREADS_X 32 // backward number of thread in feature dim
...@@ -167,7 +168,7 @@ cublasStatus_t mlp_gemm( ...@@ -167,7 +168,7 @@ cublasStatus_t mlp_gemm(
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} }
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
int mlp_gemm_lt( int mlp_gemm_lt(
cublasLtHandle_t ltHandle, cublasLtHandle_t ltHandle,
cublasOperation_t transa, cublasOperation_t transa,
...@@ -428,7 +429,7 @@ CLEANUP: ...@@ -428,7 +429,7 @@ CLEANUP:
// enqueued. // enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
} }
#endif
// Bias ADD. Assume input X is [features x batch size], column major. // Bias ADD. Assume input X is [features x batch size], column major.
// Bias is one 'features' long vector, with implicit broadcast. // Bias is one 'features' long vector, with implicit broadcast.
...@@ -1271,6 +1272,7 @@ int mlp_fp( ...@@ -1271,6 +1272,7 @@ int mlp_fp(
// try with cublaslt first for supported case with valid handle // try with cublaslt first for supported case with valid handle
int cublaslt_status = 1; int cublaslt_status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
if(activation < 1){ if(activation < 1){
cublaslt_status = mlp_gemm_lt( cublaslt_status = mlp_gemm_lt(
//ltHandle, //ltHandle,
...@@ -1295,6 +1297,7 @@ int mlp_fp( ...@@ -1295,6 +1297,7 @@ int mlp_fp(
activation == 1, activation == 1,
bias); bias);
} }
#endif
// if cublaslt failed or not executed, fallback to cublas // if cublaslt failed or not executed, fallback to cublas
if (cublaslt_status != 0) { if (cublaslt_status != 0) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment