OpenDAS / apex / Commits / 3ba7192d

Commit 3ba7192d (unverified)
Authored Sep 06, 2023 by Peng; committed by GitHub on Sep 06, 2023.
Parents: 8fc9b21f, e4d21865

Merge pull request #116 from ROCmSoftwarePlatform/revert_hipblas

    Revert "Changes to support hipblas migration (#113)"
Showing 9 changed files, with 274 additions and 296 deletions:

  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                   +36  -36
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu          +37  -37
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  +24  -24
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                +25  -25
  apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                     +24  -24
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu            +25  -25
  apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh                        +3   -51
  csrc/fused_dense_cuda.cu                                                         +85  -11
  csrc/mlp_cuda.cu                                                                 +15  -63
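Every GEMM call site in the diffs below changes the same way: the hipblas-migration shims introduced by #113 (the rocBLASStatusToHIPStatus status conversion, the hipOperationToRocOperation operation conversion, and the explicit (rocblas_handle) cast) are stripped, and rocblas_gemm_ex / rocblas_gemm_strided_batched_ex are called directly on the handle again. A minimal sketch of the reverted calling convention, assuming a ROCm build with rocblas available; the function and variable names here (gemm_f32_direct, d_a, d_b, d_c) are illustrative placeholders, not identifiers from the commit:

    // Sketch of the reverted convention: rocblas is driven directly, with no
    // hipblas status/operation conversion layer in between.
    #include <rocblas/rocblas.h>

    rocblas_status gemm_f32_direct(rocblas_handle handle,
                                   const float* d_a, const float* d_b, float* d_c,
                                   int m, int n, int k) {
      const float alpha = 1.0f;
      const float beta  = 0.0f;
      // A is k x m (transposed), B is k x n, C/D are m x n; C doubles as the D
      // output, matching the in-place pattern used throughout the diffs below.
      return rocblas_gemm_ex(handle,
                             rocblas_operation_transpose,      // transa
                             rocblas_operation_none,           // transb
                             m, n, k,
                             &alpha,
                             d_a, rocblas_datatype_f32_r, k,   // A, a_type, lda
                             d_b, rocblas_datatype_f32_r, k,   // B, b_type, ldb
                             &beta,
                             d_c, rocblas_datatype_f32_r, m,   // C, c_type, ldc
                             d_c, rocblas_datatype_f32_r, m,   // D, d_type, ldd
                             rocblas_datatype_f32_r,           // compute type
                             rocblas_gemm_algo_standard,
                             0 /*solution_index*/, 0 /*flags*/);
    }

The pre-revert form wrapped this same call as TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle, hipOperationToRocOperation(...), ...))), which is exactly the shape of the removed lines in the hunks that follow.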
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu — view file @ 3ba7192d

@@ -90,9 +90,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       output_lin_q_dim, batches_q, embed_dim,
@@ -113,12 +113,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear KV Fwd
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       output_lin_kv_dim, batches_kv, embed_dim,
@@ -139,7 +139,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,
@@ -219,9 +219,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       embed_dim, batches_q, embed_dim,
@@ -242,7 +242,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_q_results,
@@ -332,9 +332,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches_q, embed_dim,
@@ -355,12 +355,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, embed_dim, batches_q,
@@ -381,7 +381,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,
@@ -493,9 +493,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Q Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches_q, output_lin_q_dim,
@@ -516,12 +516,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear Q Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, output_lin_q_dim, batches_q,
@@ -542,12 +542,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear KV Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches_kv, output_lin_kv_dim,
@@ -568,12 +568,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear KV Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, output_lin_kv_dim, batches_kv,
@@ -594,7 +594,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_q_grads,
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu — view file @ 3ba7192d

@@ -116,9 +116,9 @@ std::vector<torch::Tensor> fwd_cuda(
   static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
   // Input Linear Q Fwd
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       output_lin_q_dim, batches_q, embed_dim,
@@ -140,12 +140,12 @@ std::vector<torch::Tensor> fwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear KV Fwd
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       output_lin_kv_dim, batches_kv, embed_dim,
@@ -166,7 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,
       b_layout_n,
@@ -246,9 +246,9 @@ std::vector<torch::Tensor> fwd_cuda(
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       embed_dim, batches_q, embed_dim,
@@ -269,7 +269,7 @@ std::vector<torch::Tensor> fwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // End-of-block Dropout-Add
   if (is_training) {
@@ -396,9 +396,9 @@ std::vector<torch::Tensor> bwd_cuda(
       (1.0 / (1.0 - dropout_prob)));
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches_q, embed_dim,
@@ -419,12 +419,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, embed_dim, batches_q,
@@ -445,7 +445,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,
@@ -557,9 +557,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Q Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches_q, output_lin_q_dim,
@@ -581,12 +581,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear Q Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, output_lin_q_dim, batches_q,
@@ -607,12 +607,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear KV Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches_kv, output_lin_kv_dim,
@@ -633,12 +633,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear KV Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, output_lin_kv_dim, batches_kv,
@@ -659,7 +659,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half, float>(
@@ -687,4 +687,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace encdec_norm_add
 } // end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu — view file @ 3ba7192d

@@ -86,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       output_lin_dim, batches, embed_dim,
@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,
@@ -183,9 +183,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   outputs.copy_(output_biases);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       embed_dim, batches, embed_dim,
@@ -206,7 +206,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_results, bmm1_results, dropout_results,
@@ -281,9 +281,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches, embed_dim,
@@ -304,12 +304,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, embed_dim, batches,
@@ -330,7 +330,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1
@@ -441,9 +441,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches, output_lin_dim,
@@ -464,12 +464,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, output_lin_dim, batches,
@@ -490,7 +490,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu — view file @ 3ba7192d

@@ -84,9 +84,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       output_lin_dim, batches, embed_dim,
@@ -107,7 +107,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,
@@ -189,9 +189,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
   outputs.copy_(output_biases);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       embed_dim, batches, embed_dim,
@@ -212,7 +212,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_results, softmax_results, dropout_results,
@@ -287,9 +287,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches, embed_dim,
@@ -310,12 +310,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, embed_dim, batches,
@@ -336,7 +336,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1
@@ -441,9 +441,9 @@ std::vector<torch::Tensor> bwd_cuda(
       attn_batches,
       flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches, output_lin_dim,
@@ -464,12 +464,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, output_lin_dim, batches,
@@ -490,7 +490,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
@@ -501,4 +501,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu — view file @ 3ba7192d

@@ -82,9 +82,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       output_lin_dim, batches, embed_dim,
@@ -105,7 +105,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,
@@ -185,9 +185,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       embed_dim, batches, embed_dim,
@@ -208,7 +208,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_results, softmax_results, dropout_results,
@@ -283,9 +283,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches, embed_dim,
@@ -306,12 +306,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, embed_dim, batches,
@@ -332,7 +332,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,
@@ -444,9 +444,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches, output_lin_dim,
@@ -467,12 +467,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, output_lin_dim, batches,
@@ -493,7 +493,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu — view file @ 3ba7192d

@@ -103,9 +103,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
   // Input Linear Fwd
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       output_lin_dim, batches, embed_dim,
@@ -127,7 +127,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,
@@ -208,9 +208,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_T, CUBLAS_OP_N,
       embed_dim, batches, embed_dim,
@@ -231,7 +231,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // End-of-block Dropout-Add
@@ -341,9 +341,9 @@ std::vector<torch::Tensor> bwd_cuda(
       (1.0 / (1.0 - dropout_prob)));
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches, embed_dim,
@@ -364,12 +364,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, embed_dim, batches,
@@ -390,7 +390,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,
@@ -502,9 +502,9 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_N,
       embed_dim, batches, output_lin_dim,
@@ -526,12 +526,12 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+      CUBLAS_OP_N, CUBLAS_OP_T,
       embed_dim, output_lin_dim, batches,
@@ -553,7 +553,7 @@ std::vector<torch::Tensor> bwd_cuda(
       rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags)));
+      flags));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half, float>(
@@ -577,4 +577,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace self_norm_add
 } // end namespace multihead_attn
\ No newline at end of file
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh — view file @ 3ba7192d

@@ -7,8 +7,6 @@
   //#include <cuda_profiler_api.h>
   #include <cuda_runtime.h>
-  #include <rocblas/rocblas.h>
   //#include <ATen/ATen.h>
   #include <ATen/cuda/CUDAContext.h>
   #include <ATen/cuda/Exceptions.h>
@@ -44,52 +42,6 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
   }
 }
-// needed to work around calling rocblas API instead of hipblas API
-static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) {
-  switch (op) {
-    case HIPBLAS_OP_N:
-      return rocblas_operation_none;
-    case HIPBLAS_OP_T:
-      return rocblas_operation_transpose;
-    case HIPBLAS_OP_C:
-      return rocblas_operation_conjugate_transpose;
-  }
-  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
-static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
-  switch (error) {
-    case rocblas_status_size_unchanged:
-    case rocblas_status_size_increased:
-    case rocblas_status_success:
-    case rocblas_status_continue:
-      return HIPBLAS_STATUS_SUCCESS;
-    case rocblas_status_invalid_handle:
-      return HIPBLAS_STATUS_NOT_INITIALIZED;
-    case rocblas_status_not_implemented:
-    case rocblas_status_excluded_from_build:
-      return HIPBLAS_STATUS_NOT_SUPPORTED;
-    case rocblas_status_invalid_pointer:
-    case rocblas_status_invalid_size:
-    case rocblas_status_invalid_value:
-    case rocblas_status_size_query_mismatch:
-      return HIPBLAS_STATUS_INVALID_VALUE;
-    case rocblas_status_memory_error:
-      return HIPBLAS_STATUS_ALLOC_FAILED;
-    case rocblas_status_internal_error:
-    case rocblas_status_perf_degraded:
-    case rocblas_status_check_numerics_fail:
-      return HIPBLAS_STATUS_INTERNAL_ERROR;
-    case rocblas_status_arch_mismatch:
-      return HIPBLAS_STATUS_ARCH_MISMATCH;
-  }
-  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
 void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
     float alpha, const half* a, long lda, long strideA, const half* b, long ldb, long strideB,
     float beta, half* c, long ldc, long strideC, half* d, long ldd, long strideD,
    long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
@@ -102,13 +54,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
   float fAlpha = alpha;
   float fBeta = beta;
   //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
+      opa, opb, (int)m, (int)n, (int)k,
       (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
       b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
       (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
       d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
-      (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)));
+      (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
 }
 void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
csrc/fused_dense_cuda.cu — view file @ 3ba7192d

@@ -10,21 +10,10 @@
   #include <cublas_v2.h>
   #include <cuda_runtime.h>
-  #include <rocblas/rocblas.h>
   #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
   // includes cublaslt
   #include <cublasLt.h>
   #endif
-  // until we use hipblas v2
-  // hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
-  // however hipblas v1 is still using its custom type
-  #define HIP_R_64F HIPBLAS_R_64F
-  #define HIP_R_32F HIPBLAS_R_32F
-  #define HIP_R_16F HIPBLAS_R_16F
   // FP64 Wrapper around cublas GEMMEx
   cublasStatus_t gemm_bias(
       cublasHandle_t handle,
@@ -41,6 +30,33 @@ cublasStatus_t gemm_bias(
       const float* beta,
       double* C,
       int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
+                         A, rocblas_datatype_f64_r, lda,
+                         B, rocblas_datatype_f64_r, ldb,
+                         beta,
+                         C, rocblas_datatype_f64_r, ldc,
+                         C, rocblas_datatype_f64_r, ldc,
+                         rocblas_datatype_f64_r,
+                         rocblas_gemm_algo_standard, 0, 0);
+#else
   return cublasGemmEx(
       handle,
       transa,
@@ -61,6 +77,7 @@ cublasStatus_t gemm_bias(
       ldc,
       CUDA_R_64F,
       CUBLAS_GEMM_DEFAULT);
+#endif
 }
 // FP32 Wrapper around cublas GEMMEx
@@ -79,6 +96,34 @@ cublasStatus_t gemm_bias(
       const float* beta,
       float* C,
       int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
+                         A, rocblas_datatype_f32_r, lda,
+                         B, rocblas_datatype_f32_r, ldb,
+                         beta,
+                         C, rocblas_datatype_f32_r, ldc,
+                         C, rocblas_datatype_f32_r, ldc,
+                         rocblas_datatype_f32_r,
+                         rocblas_gemm_algo_standard, 0, 0);
+#else
   return cublasGemmEx(
       handle,
       transa,
@@ -99,6 +144,7 @@ cublasStatus_t gemm_bias(
       ldc,
       CUDA_R_32F,
       CUBLAS_GEMM_DEFAULT);
+#endif
 }
 // FP16 Tensor core wrapper around cublas GEMMEx
@@ -117,6 +163,33 @@ cublasStatus_t gemm_bias(
       const float* beta,
       at::Half* C,
       int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
+                         A, rocblas_datatype_f16_r, lda,
+                         B, rocblas_datatype_f16_r, ldb,
+                         beta,
+                         C, rocblas_datatype_f16_r, ldc,
+                         C, rocblas_datatype_f16_r, ldc,
+                         rocblas_datatype_f32_r,
+                         rocblas_gemm_algo_standard, 0, 0);
+#else
   return cublasGemmEx(
       handle,
       transa,
@@ -137,6 +210,7 @@ cublasStatus_t gemm_bias(
       ldc,
       CUDA_R_32F,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
 }
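The gemm_bias wrappers in csrc/fused_dense_cuda.cu regain their per-backend dispatch: a __HIP_PLATFORM_HCC__ branch that returns rocblas_gemm_ex directly, with cublasGemmEx kept for CUDA builds. A condensed sketch of the FP32 shape is below; it mirrors the added lines above and assumes, as apex's build does, that PyTorch's hipify pass maps the cublas handle/operation/status types onto their rocblas counterparts on ROCm (as a standalone file, only the CUDA branch compiles as written):

    #include <cublas_v2.h>
    #ifdef __HIP_PLATFORM_HCC__
    #include <rocblas/rocblas.h>
    #endif

    // FP32 GEMM wrapper in the reverted style: rocblas on ROCm, cublas elsewhere.
    // Name gemm_bias_fp32 is illustrative; the diff's overloads are all named gemm_bias.
    cublasStatus_t gemm_bias_fp32(cublasHandle_t handle,
                                  cublasOperation_t transa, cublasOperation_t transb,
                                  int m, int n, int k,
                                  const float* alpha,
                                  const float* A, int lda,
                                  const float* B, int ldb,
                                  const float* beta,
                                  float* C, int ldc) {
    #ifdef __HIP_PLATFORM_HCC__
      // Relies on hipify translating cublasHandle_t/cublasOperation_t/cublasStatus_t
      // to the rocblas equivalents, so the arguments pass straight through.
      return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
                             A, rocblas_datatype_f32_r, lda,
                             B, rocblas_datatype_f32_r, ldb,
                             beta,
                             C, rocblas_datatype_f32_r, ldc,
                             C, rocblas_datatype_f32_r, ldc,
                             rocblas_datatype_f32_r,
                             rocblas_gemm_algo_standard, 0, 0);
    #else
      return cublasGemmEx(handle, transa, transb, m, n, k, alpha,
                          A, CUDA_R_32F, lda,
                          B, CUDA_R_32F, ldb,
                          beta,
                          C, CUDA_R_32F, ldc,
                          CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
    #endif
    }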
csrc/mlp_cuda.cu — view file @ 3ba7192d

@@ -12,8 +12,6 @@
   #include <cublas_v2.h>
   #include <cuda_runtime.h>
-  #include <rocblas/rocblas.h>
   #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
   // includes cublaslt
   #include <cublasLt.h>
@@ -60,52 +58,6 @@ __device__ __inline__ float sigmoid(float a) {
   return (retf);
 }
-// needed to work around calling rocblas API instead of hipblas API
-static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) {
-  switch (op) {
-    case HIPBLAS_OP_N:
-      return rocblas_operation_none;
-    case HIPBLAS_OP_T:
-      return rocblas_operation_transpose;
-    case HIPBLAS_OP_C:
-      return rocblas_operation_conjugate_transpose;
-  }
-  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
-static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
-  switch (error) {
-    case rocblas_status_size_unchanged:
-    case rocblas_status_size_increased:
-    case rocblas_status_success:
-    case rocblas_status_continue:
-      return HIPBLAS_STATUS_SUCCESS;
-    case rocblas_status_invalid_handle:
-      return HIPBLAS_STATUS_NOT_INITIALIZED;
-    case rocblas_status_not_implemented:
-    case rocblas_status_excluded_from_build:
-      return HIPBLAS_STATUS_NOT_SUPPORTED;
-    case rocblas_status_invalid_pointer:
-    case rocblas_status_invalid_size:
-    case rocblas_status_invalid_value:
-    case rocblas_status_size_query_mismatch:
-      return HIPBLAS_STATUS_INVALID_VALUE;
-    case rocblas_status_memory_error:
-      return HIPBLAS_STATUS_ALLOC_FAILED;
-    case rocblas_status_internal_error:
-    case rocblas_status_perf_degraded:
-    case rocblas_status_check_numerics_fail:
-      return HIPBLAS_STATUS_INTERNAL_ERROR;
-    case rocblas_status_arch_mismatch:
-      return HIPBLAS_STATUS_ARCH_MISMATCH;
-  }
-  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
 // FP64 Wrapper around cublas GEMMEx
 cublasStatus_t mlp_gemm(
     cublasHandle_t handle,
@@ -124,10 +76,10 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
-      (rocblas_handle)handle,
-      hipOperationToRocOperation(transa),
-      hipOperationToRocOperation(transb),
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
       m,
       n,
       k,
@@ -148,7 +100,7 @@ cublasStatus_t mlp_gemm(
       rocblas_datatype_f64_r,
       rocblas_gemm_algo_standard,
       0,
-      flag));
+      flag);
 #else
   return cublasGemmEx(
       handle,
@@ -191,10 +143,10 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
-      (rocblas_handle)handle,
-      hipOperationToRocOperation(transa),
-      hipOperationToRocOperation(transb),
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
       m,
       n,
       k,
@@ -215,7 +167,7 @@ cublasStatus_t mlp_gemm(
       rocblas_datatype_f32_r,
       rocblas_gemm_algo_standard,
      0,
-      flag));
+      flag);
 #else
   return cublasGemmEx(
@@ -259,10 +211,10 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
-      (rocblas_handle)handle,
-      hipOperationToRocOperation(transa),
-      hipOperationToRocOperation(transb),
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
       m,
       n,
       k,
@@ -283,7 +235,7 @@ cublasStatus_t mlp_gemm(
       rocblas_datatype_f32_r,
       rocblas_gemm_algo_standard,
       0,
-      flag));
+      flag);
 #else
   return cublasGemmEx(
       handle,