Unverified Commit 54b93919 authored by Burc Eryilmaz's avatar Burc Eryilmaz Committed by GitHub
Browse files

fix CUBLAS guards (#1162)



* support for fused dense layer with cublasLt, fusion in both fprop and bprop

* fix typo causing syntax error

* add fused GEMM+gelu+GEMM module

* fix typo for workspace size

* update cublas check for 11600

* add tests for fused dense layer

* fix CUDA 10.x path

* safer guard around CUBLAS constants, remove unreferenced variable

* more guard changes

* guard against cublas version instead of cuda
Co-authored-by: default avatarSukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
parent ae1cdd64
...@@ -62,7 +62,7 @@ std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight ...@@ -62,7 +62,7 @@ std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight
// create output/workspace tensor // create output/workspace tensor
auto d_weight = at::empty({out_features, in_features}, input.type()); auto d_weight = at::empty({out_features, in_features}, input.type());
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false); auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else #else
auto d_bias = at::empty({out_features}, input.type()); auto d_bias = at::empty({out_features}, input.type());
......
...@@ -129,7 +129,7 @@ cublasStatus_t gemm_bias( ...@@ -129,7 +129,7 @@ cublasStatus_t gemm_bias(
} }
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
int gemm_bias_lt( int gemm_bias_lt(
...@@ -1148,7 +1148,7 @@ int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int i ...@@ -1148,7 +1148,7 @@ int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int i
const float beta_zero = 0.0; const float beta_zero = 0.0;
const float beta_one = 1.0; const float beta_one = 1.0;
int status = 1; int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
status = gemm_bias_lt( status = gemm_bias_lt(
(cublasLtHandle_t)handle, (cublasLtHandle_t)handle,
CUBLAS_OP_T, CUBLAS_OP_T,
...@@ -1200,7 +1200,6 @@ int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, ...@@ -1200,7 +1200,6 @@ int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features,
cublasGetStream(handle, &stream); cublasGetStream(handle, &stream);
const float alpha = 1.0; const float alpha = 1.0;
const float beta_zero = 0.0; const float beta_zero = 0.0;
const float beta_one = 1.0;
int status = 1; int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
status = gemm_bgradb_lt( status = gemm_bgradb_lt(
...@@ -1273,7 +1272,7 @@ int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, ...@@ -1273,7 +1272,7 @@ int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2,
const float alpha = 1.0; const float alpha = 1.0;
const float beta_zero = 0.0; const float beta_zero = 0.0;
int status = 1; int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
status = gemm_bias_gelu_lt( status = gemm_bias_gelu_lt(
(cublasLtHandle_t)handle, (cublasLtHandle_t)handle,
CUBLAS_OP_T, CUBLAS_OP_T,
...@@ -1329,9 +1328,8 @@ int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight ...@@ -1329,9 +1328,8 @@ int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight
cublasGetStream(handle, &stream); cublasGetStream(handle, &stream);
const float alpha = 1.0; const float alpha = 1.0;
const float beta_zero = 0.0; const float beta_zero = 0.0;
const float beta_one = 1.0;
int status = 1; int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
//wgrad for first gemm //wgrad for first gemm
status = gemm_bgradb_lt( status = gemm_bgradb_lt(
(cublasLtHandle_t)handle, (cublasLtHandle_t)handle,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment