"...text-generation-inference.git" did not exist on "ebecc06161d3399aa1dace7be1a7a86efec85f8d"
Commit 8091b3e2 authored by Hubert Lu

Fix the hipification issues for cublasGemmEx by adding rocblas_gemm_ex

parent f79993d9
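For context: when this source is built for ROCm, the hipify tooling defines __HIP_PLATFORM_HCC__, so the commit wraps each cublasGemmEx call in a platform guard and substitutes rocblas_gemm_ex. The sketch below (the name gemm_f32 is hypothetical; only the guard and argument mapping come from the hunks that follow) shows the shape of the change for the FP32 case. Note that rocblas_gemm_ex additionally takes an explicit output matrix D, a compute type, an algorithm selector, a solution index, and a flags word:

#ifdef __HIP_PLATFORM_HCC__
#include <rocblas.h>
// Hypothetical wrapper illustrating the rocBLAS side of the guard.
static rocblas_status gemm_f32(rocblas_handle handle,
                               rocblas_operation transa, rocblas_operation transb,
                               int m, int n, int k,
                               const float* alpha, const float* A, int lda,
                               const float* B, int ldb,
                               const float* beta, float* C, int ldc) {
  return rocblas_gemm_ex(handle, transa, transb, m, n, k,
                         alpha,
                         A, rocblas_datatype_f32_r, lda,
                         B, rocblas_datatype_f32_r, ldb,
                         beta,
                         C, rocblas_datatype_f32_r, ldc,  // input matrix C
                         C, rocblas_datatype_f32_r, ldc,  // output matrix D (in place)
                         rocblas_datatype_f32_r,          // compute type
                         rocblas_gemm_algo_standard,
                         0, 0);                           // solution_index, flags
}
#else
#include <cublas_v2.h>
// Hypothetical wrapper illustrating the cuBLAS side of the guard.
static cublasStatus_t gemm_f32(cublasHandle_t handle,
                               cublasOperation_t transa, cublasOperation_t transb,
                               int m, int n, int k,
                               const float* alpha, const float* A, int lda,
                               const float* B, int ldb,
                               const float* beta, float* C, int ldc) {
  return cublasGemmEx(handle, transa, transb, m, n, k,
                      alpha, A, CUDA_R_32F, lda,
                      B, CUDA_R_32F, ldb,
                      beta, C, CUDA_R_32F, ldc,
                      CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
}
#endif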
@@ -30,6 +30,33 @@ cublasStatus_t gemm_bias(
    const float* beta,
    double* C,
    int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      rocblas_datatype_f64_r,
+      lda,
+      B,
+      rocblas_datatype_f64_r,
+      ldb,
+      beta,
+      C,
+      rocblas_datatype_f64_r,
+      ldc,
+      C,
+      rocblas_datatype_f64_r,
+      ldc,
+      rocblas_datatype_f64_r,
+      rocblas_gemm_algo_standard,
+      0,
+      0);
+#else
  return cublasGemmEx(
      handle,
      transa,
@@ -50,6 +77,7 @@ cublasStatus_t gemm_bias(
      ldc,
      CUDA_R_64F,
      CUBLAS_GEMM_DEFAULT);
+#endif
}
// FP32 Wrapper around cublas GEMMEx
@@ -68,6 +96,34 @@ cublasStatus_t gemm_bias(
    const float* beta,
    float* C,
    int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      rocblas_datatype_f32_r,
+      lda,
+      B,
+      rocblas_datatype_f32_r,
+      ldb,
+      beta,
+      C,
+      rocblas_datatype_f32_r,
+      ldc,
+      C,
+      rocblas_datatype_f32_r,
+      ldc,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard,
+      0,
+      0);
+#else
  return cublasGemmEx(
      handle,
      transa,
@@ -88,6 +144,7 @@ cublasStatus_t gemm_bias(
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT);
+#endif
}
// FP16 Tensor core wrapper around cublas GEMMEx
@@ -106,6 +163,33 @@ cublasStatus_t gemm_bias(
    const float* beta,
    at::Half* C,
    int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      rocblas_datatype_f16_r,
+      lda,
+      B,
+      rocblas_datatype_f16_r,
+      ldb,
+      beta,
+      C,
+      rocblas_datatype_f16_r,
+      ldc,
+      C,
+      rocblas_datatype_f16_r,
+      ldc,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard,
+      0,
+      0);
+#else
  return cublasGemmEx(
      handle,
      transa,
@@ -126,6 +210,7 @@ cublasStatus_t gemm_bias(
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
}
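A detail worth calling out in the FP16 hunk above: A, B, and C use rocblas_datatype_f16_r, but the compute type is rocblas_datatype_f32_r, matching CUDA_R_32F with CUBLAS_GEMM_DEFAULT_TENSOR_OP on the cuBLAS side, so products of half-precision operands are accumulated in FP32. The toy host program below (illustrative only; it emulates FP16 rounding in software rather than using real half-precision hardware) shows why a long reduction such as a GEMM inner product should not be accumulated in FP16:

#include <cstdio>
#include <cstdint>
#include <cstring>

// Round a float to FP16-like precision (10 mantissa bits) for normal values.
// Not a full IEEE binary16 conversion (no subnormal/overflow handling); it is
// only meant to show the accumulation behaviour.
static float round_to_half_precision(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits += 1u << 12;            // round to nearest at the dropped-bit boundary
  bits &= ~((1u << 13) - 1u);  // keep 10 of the 23 mantissa bits
  float y;
  std::memcpy(&y, &bits, sizeof y);
  return y;
}

int main() {
  const float addend = 0.001f;
  float acc32 = 0.0f;  // FP32 accumulator (what the FP32 compute type buys you)
  float acc16 = 0.0f;  // accumulator rounded to FP16 precision after each add
  for (int i = 0; i < 100000; ++i) {
    acc32 += addend;
    acc16 = round_to_half_precision(acc16 + round_to_half_precision(addend));
  }
  // Expected: acc32 lands close to the true 100.0, while acc16 stalls around
  // 4.0, where half an FP16 ulp already exceeds the addend.
  std::printf("fp32 accumulate: %f\nfp16 accumulate: %f\n", acc32, acc16);
  return 0;
}

Once the running sum reaches a range where the addend is below half an FP16 ulp, every further addition rounds away to nothing; keeping the accumulator in FP32 avoids exactly this failure mode.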
@@ -1148,7 +1233,7 @@ int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int i
  const float beta_zero = 0.0;
  const float beta_one = 1.0;
  int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
  status = gemm_bias_lt(
      (cublasLtHandle_t)handle,
      CUBLAS_OP_T,
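The remaining hunks relax the compile-time gate on the cublasLt paths in the two forward functions from CUBLAS_VERSION >= 11600 to >= 11000. Judging from the visible context, status starts at 1 and is only set by the Lt call when the gate is compiled in; the usual shape of such code is a gate plus a runtime fallback, sketched below with stub functions (the fallback branch is an assumption, not shown in these hunks):

// Hedged sketch of a version gate with fallback; gemm_bias_lt and
// gemm_bias_plain stand in for the file's wrappers, stubbed out so this
// compiles standalone.
static int gemm_bias_lt() { return 0; }     // cublasLt path (fused bias epilogue)
static int gemm_bias_plain() { return 0; }  // plain GemmEx path

static int forward_gemm() {
  int status = 1;  // "not handled yet"
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
  status = gemm_bias_lt();  // only compiled against cuBLAS 11.0 or newer
#endif
  if (status != 0) {
    // Assumed fallback for older cuBLAS or an Lt failure: plain GEMM,
    // with the bias added by a separate kernel.
    status = gemm_bias_plain();
  }
  return status;
}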
@@ -1200,6 +1285,7 @@ int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features,
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta_zero = 0.0;
+  const float beta_one = 1.0;
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  status = gemm_bgradb_lt(
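Both backward functions also gain a const float beta_one = 1.0; next to the existing beta_zero. In the GEMM convention used throughout this file, C = alpha * op(A) * op(B) + beta * C, so beta = 0.0 overwrites the output buffer while beta = 1.0 accumulates into it, which is typically what a gradient buffer fed by more than one GEMM needs. A 1x1 illustration of the two modes:

#include <cstdio>

// GEMM semantics in miniature: C = alpha*A*B + beta*C for 1x1 "matrices".
static float gemm1x1(float alpha, float a, float b, float beta, float c) {
  return alpha * a * b + beta * c;
}

int main() {
  float c = 5.0f;  // pretend this already holds a partial gradient
  std::printf("beta=0 (overwrite):  %f\n", gemm1x1(1.0f, 2.0f, 3.0f, 0.0f, c));  // 6.0
  std::printf("beta=1 (accumulate): %f\n", gemm1x1(1.0f, 2.0f, 3.0f, 1.0f, c));  // 11.0
  return 0;
}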
@@ -1272,7 +1358,7 @@ int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2,
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
  status = gemm_bias_gelu_lt(
      (cublasLtHandle_t)handle,
      CUBLAS_OP_T,
@@ -1328,8 +1414,9 @@ int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta_zero = 0.0;
+  const float beta_one = 1.0;
  int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
  //wgrad for first gemm
  status = gemm_bgradb_lt(
      (cublasLtHandle_t)handle,
@@ -1435,3 +1522,4 @@ template int linear_gelu_linear_backward_cuda<at::Half>(at::Half *input, at::Hal
template int linear_gelu_linear_backward_cuda<float>(float *input, float *gelu_in, float *output1, float *weight1, float *weight2, float *d_output1, float *d_output2, int in_features, int batch_size, int hidden_features, int out_features, float *d_weight1, float *d_weight2, float *d_bias1, float *d_bias2, float *d_input, void *lt_workspace);
template int linear_gelu_linear_backward_cuda<double>(double *input, double *gelu_in, double *output1, double *weight1, double *weight2, double *d_output1, double *d_output2, int in_features, int batch_size, int hidden_features, int out_features, double *d_weight1, double *d_weight2, double *d_bias1, double *d_bias2, double *d_input, void *lt_workspace);