OpenDAS / apex · Commits
Commit 8091b3e2, authored Oct 19, 2021 by Hubert Lu
Fix the hipification issues for cublasGemmEx by adding rocblas_gemm_ex
parent f79993d9
Showing 1 changed file with 91 additions and 3 deletions

csrc/fused_dense_cuda.cu (+91 / -3)
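The diff below applies one pattern to each of the three gemm_bias wrappers (FP64, FP32, FP16): the existing cublasGemmEx call is kept for CUDA builds, and an explicit rocblas_gemm_ex call is added for ROCm builds, selected at compile time by __HIP_PLATFORM_HCC__. A condensed sketch of the FP32 case follows; the datatype and algorithm constants are taken from the hunks below, while the exact gemm_bias parameter list and the assumption that the fork's hipify step rewrites the surrounding cuBLAS handle and enum types to their ROCm equivalents are reconstructions, not part of this commit.

cublasStatus_t gemm_bias(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    const float* A, int lda,
    const float* B, int ldb,
    const float* beta,
    float* C, int ldc) {
#ifdef __HIP_PLATFORM_HCC__
  // ROCm path: rocblas_gemm_ex takes explicit storage types for A, B, C and the
  // output D (aliased to C here), plus a compute type, an algorithm, a solution
  // index and flags.
  return rocblas_gemm_ex(
      handle, transa, transb, m, n, k,
      alpha,
      A, rocblas_datatype_f32_r, lda,
      B, rocblas_datatype_f32_r, ldb,
      beta,
      C, rocblas_datatype_f32_r, ldc,   // C: input matrix
      C, rocblas_datatype_f32_r, ldc,   // D: output matrix, same buffer as C
      rocblas_datatype_f32_r,           // compute type
      rocblas_gemm_algo_standard,
      0, 0);                            // solution index, flags
#else
  // CUDA path: the cublasGemmEx call is unchanged.
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_32F, lda,
      B, CUDA_R_32F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT);
#endif
}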
@@ -30,6 +30,33 @@ cublasStatus_t gemm_bias(
     const float* beta,
     double* C,
     int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      rocblas_datatype_f64_r,
+      lda,
+      B,
+      rocblas_datatype_f64_r,
+      ldb,
+      beta,
+      C,
+      rocblas_datatype_f64_r,
+      ldc,
+      C,
+      rocblas_datatype_f64_r,
+      ldc,
+      rocblas_datatype_f64_r,
+      rocblas_gemm_algo_standard,
+      0,
+      0);
+#else
   return cublasGemmEx(
       handle,
       transa,
@@ -50,6 +77,7 @@ cublasStatus_t gemm_bias(
       ldc,
       CUDA_R_64F,
       CUBLAS_GEMM_DEFAULT);
+#endif
 }

 // FP32 Wrapper around cublas GEMMEx
@@ -68,6 +96,34 @@ cublasStatus_t gemm_bias(
     const float* beta,
     float* C,
     int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      rocblas_datatype_f32_r,
+      lda,
+      B,
+      rocblas_datatype_f32_r,
+      ldb,
+      beta,
+      C,
+      rocblas_datatype_f32_r,
+      ldc,
+      C,
+      rocblas_datatype_f32_r,
+      ldc,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard,
+      0,
+      0);
+#else
   return cublasGemmEx(
       handle,
       transa,
@@ -88,6 +144,7 @@ cublasStatus_t gemm_bias(
       ldc,
       CUDA_R_32F,
       CUBLAS_GEMM_DEFAULT);
+#endif
 }

 // FP16 Tensor core wrapper around cublas GEMMEx
@@ -106,6 +163,33 @@ cublasStatus_t gemm_bias(
     const float* beta,
     at::Half* C,
     int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      rocblas_datatype_f16_r,
+      lda,
+      B,
+      rocblas_datatype_f16_r,
+      ldb,
+      beta,
+      C,
+      rocblas_datatype_f16_r,
+      ldc,
+      C,
+      rocblas_datatype_f16_r,
+      ldc,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard,
+      0,
+      0);
+#else
   return cublasGemmEx(
       handle,
       transa,
@@ -126,6 +210,7 @@ cublasStatus_t gemm_bias(
       ldc,
       CUDA_R_32F,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
 }
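As a side note (not part of the commit), the constants chosen in the three branches map one-to-one: CUDA_R_64F and CUDA_R_32F become rocblas_datatype_f64_r and rocblas_datatype_f32_r, the at::Half operands use rocblas_datatype_f16_r, both CUBLAS_GEMM_DEFAULT and CUBLAS_GEMM_DEFAULT_TENSOR_OP are replaced by rocblas_gemm_algo_standard, and the FP16 wrapper keeps a 32-bit compute type on both paths (CUDA_R_32F vs rocblas_datatype_f32_r). A hypothetical trait that makes the storage-type mapping explicit, purely for illustration:

#ifdef __HIP_PLATFORM_HCC__
#include <rocblas.h>          // rocblas_datatype (header path assumed for a 2021-era ROCm)
#include <ATen/ATen.h>        // at::Half

// Hypothetical helper, not used by the commit: element type -> rocBLAS storage type.
template <typename T> struct rocblas_dtype_for;
template <> struct rocblas_dtype_for<double>   { static constexpr rocblas_datatype value = rocblas_datatype_f64_r; };
template <> struct rocblas_dtype_for<float>    { static constexpr rocblas_datatype value = rocblas_datatype_f32_r; };
template <> struct rocblas_dtype_for<at::Half> { static constexpr rocblas_datatype value = rocblas_datatype_f16_r; };
#endif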
@@ -1148,7 +1233,7 @@ int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int i
   const float beta_zero = 0.0;
   const float beta_one = 1.0;
   int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
   status = gemm_bias_lt(
     (cublasLtHandle_t)handle,
     CUBLAS_OP_T,
@@ -1200,6 +1285,7 @@ int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features,
   cublasGetStream(handle, &stream);
   const float alpha = 1.0;
   const float beta_zero = 0.0;
+  const float beta_one = 1.0;
   int status = 1;
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
   status = gemm_bgradb_lt(
@@ -1272,7 +1358,7 @@ int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2,
   const float alpha = 1.0;
   const float beta_zero = 0.0;
   int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
   status = gemm_bias_gelu_lt(
     (cublasLtHandle_t)handle,
     CUBLAS_OP_T,
@@ -1328,8 +1414,9 @@ int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight
   cublasGetStream(handle, &stream);
   const float alpha = 1.0;
   const float beta_zero = 0.0;
+  const float beta_one = 1.0;
   int status = 1;
-#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
   //wgrad for first gemm
   status = gemm_bgradb_lt(
     (cublasLtHandle_t)handle,
@@ -1435,3 +1522,4 @@ template int linear_gelu_linear_backward_cuda<at::Half>(at::Half *input, at::Hal
 template int linear_gelu_linear_backward_cuda<float>(float *input, float *gelu_in, float *output1, float *weight1, float *weight2, float *d_output1, float *d_output2, int in_features, int batch_size, int hidden_features, int out_features, float *d_weight1, float *d_weight2, float *d_bias1, float *d_bias2, float *d_input, void *lt_workspace);
 template int linear_gelu_linear_backward_cuda<double>(double *input, double *gelu_in, double *output1, double *weight1, double *weight2, double *d_output1, double *d_output2, int in_features, int batch_size, int hidden_features, int out_features, double *d_weight1, double *d_weight2, double *d_bias1, double *d_bias2, double *d_input, void *lt_workspace);
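For context, here is a sketch of how one of these wrappers might be invoked from the extension. The forward declaration mirrors the FP32 overload visible in the diff, but the call-site shape (a weight^T * input GEMM with unit alpha and zero beta) and the function name example_linear_forward are assumptions for illustration, not taken from this commit.

#include <cublas_v2.h>
#include <ATen/cuda/CUDAContext.h>   // at::cuda::getCurrentCUDABlasHandle()

// FP32 overload as declared in csrc/fused_dense_cuda.cu (reconstructed from the diff above).
cublasStatus_t gemm_bias(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    const float* A, int lda,
    const float* B, int ldb,
    const float* beta,
    float* C, int ldc);

// Hypothetical caller: output = weight^T * input, the usual linear-layer forward GEMM.
cublasStatus_t example_linear_forward(const float* weight, const float* input, float* output,
                                      int in_features, int batch_size, int out_features) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  // The same call compiles on CUDA and (after hipify) on ROCm, because the
  // cublasGemmEx / rocblas_gemm_ex split is hidden inside gemm_bias.
  return gemm_bias(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                   out_features, batch_size, in_features,
                   &alpha,
                   weight, in_features,   // A: weight, lda = in_features
                   input, in_features,    // B: input,  ldb = in_features
                   &beta,
                   output, out_features); // C: output, ldc = out_features
}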