OpenDAS / apex · Commits · 8fc9b21f

Unverified commit 8fc9b21f, authored Aug 11, 2023 by Pruthvi Madugundu, committed by GitHub on Aug 11, 2023

Changes to support hipblas migration (#113)

Parent: 10c74820
Showing 9 changed files with 296 additions and 274 deletions (+296 −274).
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                   +36  -36
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu          +37  -37
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  +24  -24
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                +25  -25
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                     +24  -24
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu            +25  -25
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh                        +51  -3
csrc/fused_dense_cuda.cu                                                         +11  -85
csrc/mlp_cuda.cu                                                                 +63  -15
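The same mechanical change is applied at every BLAS call site in the files below: cast the handle to rocblas_handle, convert the operation enums through hipOperationToRocOperation, and translate the returned rocblas_status back to a hipblas status with rocBLASStatusToHIPStatus before handing it to TORCH_CUDABLAS_CHECK. As a compact before/after sketch (argument list elided, taken from the hunks that follow):

  // before: hipblas-style handle and enums passed straight to rocBLAS
  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                                       CUBLAS_OP_T,
                                       CUBLAS_OP_N,
                                       /* ...dims, pointers, types... */
                                       flags));

  // after: explicit handle cast, enum conversion, and status translation
  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
                                       hipOperationToRocOperation(CUBLAS_OP_T),
                                       hipOperationToRocOperation(CUBLAS_OP_N),
                                       /* ...dims, pointers, types... */
                                       flags)));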
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu  (view file @ 8fc9b21f)

@@ -90,9 +90,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       output_lin_q_dim, batches_q, embed_dim,

@@ -113,12 +113,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear KV Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       output_lin_kv_dim, batches_kv, embed_dim,

@@ -139,7 +139,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,

@@ -219,9 +219,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches_q, embed_dim,

@@ -242,7 +242,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_q_results,

@@ -332,9 +332,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches_q, embed_dim,

@@ -355,12 +355,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, embed_dim, batches_q,

@@ -381,7 +381,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,

@@ -493,9 +493,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Q Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches_q, output_lin_q_dim,

@@ -516,12 +516,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear Q Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, output_lin_q_dim, batches_q,

@@ -542,12 +542,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear KV Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches_kv, output_lin_kv_dim,

@@ -568,12 +568,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear KV Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, output_lin_kv_dim, batches_kv,

@@ -594,7 +594,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_q_grads,
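Since every one of these call sites repeats the same cast/convert/translate steps, one way to cut the duplication would be a small forwarding helper. This is only a sketch and is not part of the commit; the name gemm_ex_as_hipblas is hypothetical.

  // Hypothetical helper (NOT in this commit): performs the handle cast, the
  // operation-enum conversion, and the status translation in one place, so a
  // call site could stay close to its original hipblas-style spelling.
  template <typename... Args>
  static hipblasStatus_t gemm_ex_as_hipblas(cublasHandle_t handle,
                                            cublasOperation_t transa,
                                            cublasOperation_t transb,
                                            Args... args) {
    return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
                                                    hipOperationToRocOperation(transa),
                                                    hipOperationToRocOperation(transb),
                                                    args...));
  }

  // A call site would then read, for example:
  //   TORCH_CUDABLAS_CHECK(gemm_ex_as_hipblas(handle, CUBLAS_OP_T, CUBLAS_OP_N, /* ... */, flags));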
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu  (view file @ 8fc9b21f)

@@ -116,9 +116,9 @@ std::vector<torch::Tensor> fwd_cuda(
       static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
   // Input Linear Q Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       output_lin_q_dim, batches_q, embed_dim,

@@ -140,12 +140,12 @@ std::vector<torch::Tensor> fwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear KV Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       output_lin_kv_dim, batches_kv, embed_dim,

@@ -166,7 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t, b_layout_n,

@@ -246,9 +246,9 @@ std::vector<torch::Tensor> fwd_cuda(
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches_q, embed_dim,

@@ -269,7 +269,7 @@ std::vector<torch::Tensor> fwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // End-of-block Dropout-Add
   if (is_training) {

@@ -396,9 +396,9 @@ std::vector<torch::Tensor> bwd_cuda(
       (1.0 / (1.0 - dropout_prob)));
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches_q, embed_dim,

@@ -419,12 +419,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, embed_dim, batches_q,

@@ -445,7 +445,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,

@@ -557,9 +557,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Q Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches_q, output_lin_q_dim,

@@ -581,12 +581,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear Q Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, output_lin_q_dim, batches_q,

@@ -607,12 +607,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear KV Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches_kv, output_lin_kv_dim,

@@ -633,12 +633,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear KV Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, output_lin_kv_dim, batches_kv,

@@ -659,7 +659,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half, float>(
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  (view file @ 8fc9b21f)

@@ -86,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       output_lin_dim, batches, embed_dim,

@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,

@@ -183,9 +183,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   outputs.copy_(output_biases);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, embed_dim,

@@ -206,7 +206,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_results, bmm1_results, dropout_results,

@@ -281,9 +281,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, embed_dim,

@@ -304,12 +304,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, embed_dim, batches,

@@ -330,7 +330,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1

@@ -441,9 +441,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, output_lin_dim,

@@ -464,12 +464,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, output_lin_dim, batches,

@@ -490,7 +490,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu  (view file @ 8fc9b21f)

@@ -84,9 +84,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       output_lin_dim, batches, embed_dim,

@@ -107,7 +107,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,

@@ -189,9 +189,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
   outputs.copy_(output_biases);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, embed_dim,

@@ -212,7 +212,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_results, softmax_results, dropout_results,

@@ -287,9 +287,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, embed_dim,

@@ -310,12 +310,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, embed_dim, batches,

@@ -336,7 +336,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1

@@ -441,9 +441,9 @@ std::vector<torch::Tensor> bwd_cuda(
       attn_batches, flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, output_lin_dim,

@@ -464,12 +464,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, output_lin_dim, batches,

@@ -490,7 +490,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu  (view file @ 8fc9b21f)

@@ -82,9 +82,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       output_lin_dim, batches, embed_dim,

@@ -105,7 +105,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,

@@ -185,9 +185,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, embed_dim,

@@ -208,7 +208,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_results, softmax_results, dropout_results,

@@ -283,9 +283,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, embed_dim,

@@ -306,12 +306,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, embed_dim, batches,

@@ -332,7 +332,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,

@@ -444,9 +444,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, output_lin_dim,

@@ -467,12 +467,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, output_lin_dim, batches,

@@ -493,7 +493,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  (view file @ 8fc9b21f)

@@ -103,9 +103,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
   // Input Linear Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       output_lin_dim, batches, embed_dim,

@@ -127,7 +127,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,

@@ -208,9 +208,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, embed_dim,

@@ -231,7 +231,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // End-of-block Dropout-Add

@@ -341,9 +341,9 @@ std::vector<torch::Tensor> bwd_cuda(
       (1.0 / (1.0 - dropout_prob)));
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, embed_dim,

@@ -364,12 +364,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, embed_dim, batches,

@@ -390,7 +390,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,

@@ -502,9 +502,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
       embed_dim, batches, output_lin_dim,

@@ -526,12 +526,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
       embed_dim, output_lin_dim, batches,

@@ -553,7 +553,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half, float>(
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh  (view file @ 8fc9b21f)

@@ -7,6 +7,8 @@
 //#include <cuda_profiler_api.h>
 #include <cuda_runtime.h>
 #include <rocblas/rocblas.h>
 //#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>

@@ -42,6 +44,52 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
   }
 }

+// needed to work around calling rocblas API instead of hipblas API
+static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) {
+  switch (op) {
+  case HIPBLAS_OP_N:
+    return rocblas_operation_none;
+  case HIPBLAS_OP_T:
+    return rocblas_operation_transpose;
+  case HIPBLAS_OP_C:
+    return rocblas_operation_conjugate_transpose;
+  }
+  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
+}
+
+static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
+  switch (error) {
+  case rocblas_status_size_unchanged:
+  case rocblas_status_size_increased:
+  case rocblas_status_success:
+  case rocblas_status_continue:
+    return HIPBLAS_STATUS_SUCCESS;
+  case rocblas_status_invalid_handle:
+    return HIPBLAS_STATUS_NOT_INITIALIZED;
+  case rocblas_status_not_implemented:
+  case rocblas_status_excluded_from_build:
+    return HIPBLAS_STATUS_NOT_SUPPORTED;
+  case rocblas_status_invalid_pointer:
+  case rocblas_status_invalid_size:
+  case rocblas_status_invalid_value:
+  case rocblas_status_size_query_mismatch:
+    return HIPBLAS_STATUS_INVALID_VALUE;
+  case rocblas_status_memory_error:
+    return HIPBLAS_STATUS_ALLOC_FAILED;
+  case rocblas_status_internal_error:
+  case rocblas_status_perf_degraded:
+  case rocblas_status_check_numerics_fail:
+    return HIPBLAS_STATUS_INTERNAL_ERROR;
+  case rocblas_status_arch_mismatch:
+    return HIPBLAS_STATUS_ARCH_MISMATCH;
+  }
+  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
+}
+
 void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
     float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
     float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD,
     long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {

@@ -54,13 +102,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
   float fAlpha = alpha;
   float fBeta = beta;
   //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(opa), hipOperationToRocOperation(opb),
       (int)m, (int)n, (int)k,
       (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
       b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
       (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
       d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
-      (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
+      (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)));
 }

 void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
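The two helpers added in this header are straightforward lookup tables between the hipblas and rocblas enums. A minimal illustration of what they return (assumes <cassert>; this check function is not part of the commit):

  #include <cassert>

  // Illustrative only (not in the commit): spot-check the enum mappings
  // defined by hipOperationToRocOperation and rocBLASStatusToHIPStatus.
  static void check_hipblas_rocblas_mappings() {
    assert(hipOperationToRocOperation(HIPBLAS_OP_N) == rocblas_operation_none);
    assert(hipOperationToRocOperation(HIPBLAS_OP_T) == rocblas_operation_transpose);
    assert(hipOperationToRocOperation(HIPBLAS_OP_C) == rocblas_operation_conjugate_transpose);
    assert(rocBLASStatusToHIPStatus(rocblas_status_success) == HIPBLAS_STATUS_SUCCESS);
    assert(rocBLASStatusToHIPStatus(rocblas_status_invalid_handle) == HIPBLAS_STATUS_NOT_INITIALIZED);
    assert(rocBLASStatusToHIPStatus(rocblas_status_memory_error) == HIPBLAS_STATUS_ALLOC_FAILED);
  }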
csrc/fused_dense_cuda.cu  (view file @ 8fc9b21f)

@@ -10,10 +10,21 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
 #include <rocblas/rocblas.h>
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
 // includes cublaslt
 #include <cublasLt.h>
 #endif
+// until we use hipblas v2
+// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
+// however hipblas v1 is still using its custom type
+#define HIP_R_64F HIPBLAS_R_64F
+#define HIP_R_32F HIPBLAS_R_32F
+#define HIP_R_16F HIPBLAS_R_16F

 // FP64 Wrapper around cublas GEMMEx
 cublasStatus_t gemm_bias(
     cublasHandle_t handle,

@@ -30,33 +41,6 @@ cublasStatus_t gemm_bias(
     const float* beta,
     double* C,
     int ldc) {
-#ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
-      A, rocblas_datatype_f64_r, lda,
-      B, rocblas_datatype_f64_r, ldb,
-      beta,
-      C, rocblas_datatype_f64_r, ldc,
-      C, rocblas_datatype_f64_r, ldc,
-      rocblas_datatype_f64_r,
-      rocblas_gemm_algo_standard,
-      0, 0);
-#else
   return cublasGemmEx(handle,
       transa,

@@ -77,7 +61,6 @@ cublasStatus_t gemm_bias(
       ldc,
       CUDA_R_64F,
       CUBLAS_GEMM_DEFAULT);
-#endif
 }

 // FP32 Wrapper around cublas GEMMEx

@@ -96,34 +79,6 @@ cublasStatus_t gemm_bias(
     const float* beta,
     float* C,
     int ldc) {
-#ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
-      A, rocblas_datatype_f32_r, lda,
-      B, rocblas_datatype_f32_r, ldb,
-      beta,
-      C, rocblas_datatype_f32_r, ldc,
-      C, rocblas_datatype_f32_r, ldc,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard,
-      0, 0);
-#else
   return cublasGemmEx(handle,
       transa,

@@ -144,7 +99,6 @@ cublasStatus_t gemm_bias(
       ldc,
       CUDA_R_32F,
       CUBLAS_GEMM_DEFAULT);
-#endif
 }

 // FP16 Tensor core wrapper around cublas GEMMEx

@@ -163,33 +117,6 @@ cublasStatus_t gemm_bias(
     const float* beta,
     at::Half* C,
     int ldc) {
-#ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
-      A, rocblas_datatype_f16_r, lda,
-      B, rocblas_datatype_f16_r, ldb,
-      beta,
-      C, rocblas_datatype_f16_r, ldc,
-      C, rocblas_datatype_f16_r, ldc,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard,
-      0, 0);
-#else
   return cublasGemmEx(handle,
       transa,

@@ -210,7 +137,6 @@ cublasStatus_t gemm_bias(
       ldc,
       CUDA_R_32F,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
-#endif
 }
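The #define lines added at the top of this file bridge a naming gap: hipify rewrites CUDA_R_64F/CUDA_R_32F/CUDA_R_16F to HIP_R_64F/HIP_R_32F/HIP_R_16F, but hipblas v1 still expects its own HIPBLAS_R_* enumerators, so the per-precision rocblas branches can be dropped and the cublasGemmEx path can be hipified directly. A minimal sketch of the effect (illustrative, assuming hipblas v1's hipblasDatatype_t; not taken from the file):

  // Illustrative only: with the shim in place, a hipified datatype name
  // compiles against hipblas v1's own enum.
  #define HIP_R_16F HIPBLAS_R_16F                // as added by this commit
  hipblasDatatype_t a_type = HIP_R_16F;          // resolves to HIPBLAS_R_16F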
csrc/mlp_cuda.cu  (view file @ 8fc9b21f)

@@ -12,6 +12,8 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
 #include <rocblas/rocblas.h>
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
 // includes cublaslt
 #include <cublasLt.h>

@@ -58,6 +60,52 @@ __device__ __inline__ float sigmoid(float a) {
   return (retf);
 }

+// needed to work around calling rocblas API instead of hipblas API
+static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) {
+  switch (op) {
+  case HIPBLAS_OP_N:
+    return rocblas_operation_none;
+  case HIPBLAS_OP_T:
+    return rocblas_operation_transpose;
+  case HIPBLAS_OP_C:
+    return rocblas_operation_conjugate_transpose;
+  }
+  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
+}
+
+static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
+  switch (error) {
+  case rocblas_status_size_unchanged:
+  case rocblas_status_size_increased:
+  case rocblas_status_success:
+  case rocblas_status_continue:
+    return HIPBLAS_STATUS_SUCCESS;
+  case rocblas_status_invalid_handle:
+    return HIPBLAS_STATUS_NOT_INITIALIZED;
+  case rocblas_status_not_implemented:
+  case rocblas_status_excluded_from_build:
+    return HIPBLAS_STATUS_NOT_SUPPORTED;
+  case rocblas_status_invalid_pointer:
+  case rocblas_status_invalid_size:
+  case rocblas_status_invalid_value:
+  case rocblas_status_size_query_mismatch:
+    return HIPBLAS_STATUS_INVALID_VALUE;
+  case rocblas_status_memory_error:
+    return HIPBLAS_STATUS_ALLOC_FAILED;
+  case rocblas_status_internal_error:
+  case rocblas_status_perf_degraded:
+  case rocblas_status_check_numerics_fail:
+    return HIPBLAS_STATUS_INTERNAL_ERROR;
+  case rocblas_status_arch_mismatch:
+    return HIPBLAS_STATUS_ARCH_MISMATCH;
+  }
+  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
+}
+
 // FP64 Wrapper around cublas GEMMEx
 cublasStatus_t mlp_gemm(
     cublasHandle_t handle,

@@ -76,10 +124,10 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(handle,
-      transa,
-      transb,
+  return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(transa),
+      hipOperationToRocOperation(transb),
       m, n, k,

@@ -100,7 +148,7 @@ cublasStatus_t mlp_gemm(
       rocblas_datatype_f64_r,
       rocblas_gemm_algo_standard,
       0,
-      flag);
+      flag));
 #else
   return cublasGemmEx(handle,

@@ -143,10 +191,10 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(handle,
-      transa,
-      transb,
+  return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(transa),
+      hipOperationToRocOperation(transb),
      m, n, k,

@@ -167,7 +215,7 @@ cublasStatus_t mlp_gemm(
       rocblas_datatype_f32_r,
       rocblas_gemm_algo_standard,
       0,
-      flag);
+      flag));
 #else
   return cublasGemmEx(

@@ -211,10 +259,10 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(handle,
-      transa,
-      transb,
+  return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(transa),
+      hipOperationToRocOperation(transb),
      m, n, k,

@@ -235,7 +283,7 @@ cublasStatus_t mlp_gemm(
       rocblas_datatype_f32_r,
       rocblas_gemm_algo_standard,
       0,
-      flag);
+      flag));
 #else
   return cublasGemmEx(handle,
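Because the ROCm branch of mlp_gemm now routes the rocblas_status through rocBLASStatusToHIPStatus, the function returns the hipblas-style status its (hipified) cublasStatus_t signature declares on both build paths, rather than a raw rocblas_status. A hedged caller-side sketch (illustrative only, not from the commit; the argument list is elided and the variable name is chosen here):

  // Illustrative caller: the returned status can be checked the same way on
  // CUDA and ROCm builds.
  cublasStatus_t status = mlp_gemm(handle, transa, transb,
                                   /* ...m, n, k, alpha, A, lda, B, ldb, beta, C... */
                                   ldc, flag);
  if (status != CUBLAS_STATUS_SUCCESS) {
    // handle or report the GEMM failure here
  }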
flyingdown (@flyingdown) mentioned this in commit e4d218653b4143a7bd7cc11d88c88528be473aad · Sep 18, 2023