OpenDAS / apex / Commits / b5d7745d

Commit b5d7745d, authored Sep 18, 2023 by flyingdown

    merge mirror master

Parents: 03204b84, 3ba7192d

Showing 9 changed files with 1900 additions and 3593 deletions (+1900 / -3593)
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                    +319  -574
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu           +325  -585
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu   +312  -636
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                 +307  -625
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                      +200  -400
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu             +291  -597
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh                         +3    -51
csrc/fused_dense_cuda.cu                                                          +108  -12
csrc/mlp_cuda.cu                                                                  +35   -113
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu  (view file @ b5d7745d)

@@ -94,155 +94,79 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
   if (use_fp16) {
-    TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N,
-        output_lin_q_dim, batches_q, embed_dim,
-        static_cast<const void*>(&alpha),
-        static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(inputs_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(&beta),
-        q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_q_dim,
-        q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_q_dim,
-        rocblas_datatype_f32_r,
-        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
+    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
+        (rocblas_handle)handle,
+        hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+        output_lin_q_dim, batches_q, embed_dim,
+        static_cast<const void*>(&h_alpha),
+        static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+        static_cast<const void*>(inputs_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+        static_cast<const void*>(&h_beta),
+        q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_q_dim,
+        q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_q_dim,
+        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
+        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
     // Input Linear KV Fwd: the same wrapped rocblas_gemm_ex call for
     // output_lin_kv_dim x batches_kv x embed_dim over input_weights_kv / inputs_kv
     // into k_lin_results_ptr (leading dimension output_lin_kv_dim).
     // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
     gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, h_scale,
         static_cast<const half*>(k_lin_results_ptr), lead_dim_kv, batch_stride_kv,
         static_cast<const half*>(q_lin_results_ptr), lead_dim_q, batch_stride_q, h_beta,
         static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
         static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
         attn_batches, flags);
   } else {
     // Input Linear Q Fwd, Input Linear KV Fwd and MatMul1 issued through the same
     // wrapped rocblas_gemm_ex / gemm_switch_fp32accum calls, but with alpha/beta/scale
     // and a rocblas_datatype_f32_r compute type.
   }
   // Padded Softmax
   bool softmax_success = false;
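The change repeated throughout these hunks is mechanical: each bare rocblas_gemm_ex(handle, CUBLAS_OP_*, ...) call becomes rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle, hipOperationToRocOperation(CUBLAS_OP_*), ...)), presumably because TORCH_CUDABLAS_CHECK now expects a hipBLAS-style status and the hipified handle/operation types no longer convert to the rocBLAS ones implicitly. The helper names come from the diff itself; the sketch below is only my illustration of what such shims typically do, with deliberately minimal enum coverage and _sketch suffixes, not apex's or hipBLAS's actual implementation.

    #include <hipblas/hipblas.h>
    #include <rocblas/rocblas.h>

    // Sketch (assumption, not the real internals): map the HIP-side operation enum
    // onto rocBLAS, and translate a rocBLAS status back into a hipBLAS status so a
    // TORCH_CUDABLAS_CHECK-style macro can consume it.
    static rocblas_operation hipOperationToRocOperation_sketch(hipblasOperation_t op) {
      switch (op) {
        case HIPBLAS_OP_N: return rocblas_operation_none;
        case HIPBLAS_OP_T: return rocblas_operation_transpose;
        case HIPBLAS_OP_C: return rocblas_operation_conjugate_transpose;
      }
      return rocblas_operation_none;  // unreachable for valid input
    }

    static hipblasStatus_t rocBLASStatusToHIPStatus_sketch(rocblas_status st) {
      switch (st) {
        case rocblas_status_success:         return HIPBLAS_STATUS_SUCCESS;
        case rocblas_status_invalid_handle:  return HIPBLAS_STATUS_NOT_INITIALIZED;
        case rocblas_status_invalid_pointer: return HIPBLAS_STATUS_INVALID_VALUE;
        case rocblas_status_memory_error:    return HIPBLAS_STATUS_ALLOC_FAILED;
        default:                             return HIPBLAS_STATUS_INTERNAL_ERROR;
      }
    }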
@@ -276,104 +200,53 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   }
   // Matmul2
   if (use_fp16) {
     gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, h_alpha,
         static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
         (is_training) ? static_cast<const half*>(dropout_results.data_ptr())
                       : static_cast<const half*>(softmax_results.data_ptr()),
         k_seq_len, k_seq_len * q_seq_len, h_beta,
         static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
         static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
         attn_batches, flags);
     // Output Linear
     TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
         (rocblas_handle)handle,
         hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
         embed_dim, batches_q, embed_dim,
         static_cast<const void*>(&h_alpha),
         static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         static_cast<const void*>(&h_beta),
         static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
         rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
   } else {
     // Same Matmul2 (gemm_switch_fp32accum) and Output Linear (wrapped rocblas_gemm_ex)
     // calls, issued with alpha/beta and a rocblas_datatype_f32_r compute type.
   }
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_q_results,
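For reference, the call pattern above maps onto the following self-contained rocblas_gemm_ex example (my own minimal sketch, not code from apex): fp16 A/B/C/D buffers with the accumulation precision chosen through the compute_type argument, which is the knob the fp16 branch above changes while leaving the rocblas_datatype_f32_r alternative commented out.

    #include <rocblas/rocblas.h>
    #include <hip/hip_runtime.h>

    // Minimal sketch: C = alpha * A^T * B + beta * C with fp16 storage.
    // Switching compute_type between rocblas_datatype_f32_r and rocblas_datatype_f16_r
    // trades accumulation precision for speed, the choice made in the fp16 path above.
    int main() {
      const int m = 64, n = 32, k = 128;
      rocblas_handle handle;
      rocblas_create_handle(&handle);

      rocblas_half *dA, *dB, *dC;
      hipMalloc(&dA, sizeof(rocblas_half) * k * m);   // A is k x m, used transposed
      hipMalloc(&dB, sizeof(rocblas_half) * k * n);   // B is k x n
      hipMalloc(&dC, sizeof(rocblas_half) * m * n);

      const float alpha = 1.0f, beta = 0.0f;          // scalars live in compute precision (f32 here)
      rocblas_status st = rocblas_gemm_ex(
          handle, rocblas_operation_transpose, rocblas_operation_none,
          m, n, k,
          &alpha,
          dA, rocblas_datatype_f16_r, k,
          dB, rocblas_datatype_f16_r, k,
          &beta,
          dC, rocblas_datatype_f16_r, m,
          dC, rocblas_datatype_f16_r, m,              // D aliases C, as in the calls above
          rocblas_datatype_f32_r,                     // accumulate in fp32
          rocblas_gemm_algo_standard, 0 /*solution_index*/, 0 /*flags*/);

      hipFree(dA); hipFree(dB); hipFree(dC);
      rocblas_destroy_handle(handle);
      return st == rocblas_status_success ? 0 : 1;
    }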
@@ -465,32 +338,57 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
 #endif
   if (use_fp16) {
     // Output Linear Dgrad
     TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
         (rocblas_handle)handle,
         hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
         embed_dim, batches_q, embed_dim,
         static_cast<const void*>(&h_alpha),
         static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         static_cast<const void*>(&h_beta),
         static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
         rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Output Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
-        embed_dim, embed_dim, batches_q,
-        static_cast<const void*>(&alpha),
-        static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(&beta),
-        static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        rocblas_datatype_f32_r,
-        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
     // Output Linear Wgrad
     TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
         (rocblas_handle)handle,
@@ -680,308 +578,155 @@ std::vector<torch::Tensor> bwd_cuda(
                                k_seq_len, attn_batches * q_seq_len);
   assert(softmax_success);
   if (use_fp16) {
     // Matmul1 Dgrad1
     gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, h_scale,
         k_lin_results_ptr, lead_dim_kv, batch_stride_kv,
         static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
         q_lin_grads_ptr, lead_dim_q, batch_stride_q,
         q_lin_grads_ptr, lead_dim_q, batch_stride_q, attn_batches, flags);
     // Matmul1 Dgrad2
     gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, h_scale,
         q_lin_results_ptr, lead_dim_q, batch_stride_q,
         static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
         k_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
         k_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches, flags);
     // Input Linear Q Dgrad
     TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
         (rocblas_handle)handle,
         hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
         embed_dim, batches_q, output_lin_q_dim,
         static_cast<const void*>(&h_alpha),
         static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_q_dim,
         static_cast<const void*>(&h_beta),
         static_cast<void*>(input_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         static_cast<void*>(input_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
         /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
         rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
     // Input Linear Q Wgrad   (CUBLAS_OP_N, CUBLAS_OP_T; embed_dim x output_lin_q_dim x batches_q;
     //                         inputs_q, q_lin_grads_ptr -> input_weight_q_grads)
     // Input Linear KV Dgrad  (CUBLAS_OP_N, CUBLAS_OP_N; embed_dim x batches_kv x output_lin_kv_dim;
     //                         input_weights_kv, k_lin_grads_ptr -> input_kv_grads)
     // Input Linear KV Wgrad  (CUBLAS_OP_N, CUBLAS_OP_T; embed_dim x output_lin_kv_dim x batches_kv;
     //                         inputs_kv, k_lin_grads_ptr -> input_weight_kv_grads)
     // Each issued through the same wrapped rocblas_gemm_ex call with h_alpha/h_beta and the
     // /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r compute type.
   } else {
     // The same six operations (Matmul1 Dgrad1/Dgrad2 and the four input-linear Dgrad/Wgrad
     // GEMMs) issued with scale/alpha/beta and a rocblas_datatype_f32_r compute type.
   }
   // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_q_grads,
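The batched attention GEMMs in these hunks go through gemm_switch_fp32accum, presumably a helper defined in the strided_batched_gemm.cuh header this commit also touches. The sketch below shows one plausible way such a helper could issue the MatMul1 score computation via rocblas_gemm_strided_batched_ex with fp16 operands and fp32 accumulation; the function name, the single shared leading dimension, and the argument choices are my assumptions for illustration, not apex's implementation.

    #include <rocblas/rocblas.h>

    // Sketch only: per-head K^T * Q score GEMM (MatMul1 above) with fp16 data and
    // fp32 accumulation. A single lead_dim/batch_stride is used for brevity; the
    // encdec kernels keep separate Q and KV leading dimensions.
    void attention_scores_sketch(rocblas_handle handle,
                                 const rocblas_half* k, const rocblas_half* q, rocblas_half* scores,
                                 int head_dim, int q_seq_len, int k_seq_len,
                                 int lead_dim, long long batch_stride, int attn_batches,
                                 float scale) {
      const float beta = 0.0f;
      rocblas_gemm_strided_batched_ex(
          handle,
          rocblas_operation_transpose, rocblas_operation_none,   // mirrors a_layout_t / b_layout_n
          k_seq_len, q_seq_len, head_dim,
          &scale,
          k, rocblas_datatype_f16_r, lead_dim, batch_stride,
          q, rocblas_datatype_f16_r, lead_dim, batch_stride,
          &beta,
          scores, rocblas_datatype_f16_r, k_seq_len, (long long)k_seq_len * q_seq_len,
          scores, rocblas_datatype_f16_r, k_seq_len, (long long)k_seq_len * q_seq_len,
          attn_batches,
          rocblas_datatype_f32_r,                                // fp32 accumulation
          rocblas_gemm_algo_standard, 0 /*solution_index*/, 0 /*flags*/);
    }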
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu  (view file @ b5d7745d)

@@ -119,158 +119,80 @@ std::vector<torch::Tensor> fwd_cuda(
                             1.0e-5,
                             static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),
                             static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
   if (use_fp16) {
     // Input Linear Q Fwd
     TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
         (rocblas_handle)handle,
         hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
         output_lin_q_dim, batches_q, embed_dim,
         static_cast<const void*>(&h_alpha),
         static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
         //static_cast<const void*>(inputs_q.data_ptr()),
         static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
         static_cast<const void*>(&h_beta),
         q_lin_results_ptr, rocblas_datatype_f16_r /*c_type*/, output_lin_q_dim,
         q_lin_results_ptr, rocblas_datatype_f16_r /*d_type*/, output_lin_q_dim,
         /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
         rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
     // Input Linear KV Fwd: the same wrapped call for output_lin_kv_dim x batches_kv x embed_dim
     // over input_weights_kv / inputs_kv into k_lin_results_ptr.
     // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
     gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, h_scale,
         static_cast<const half*>(k_lin_results_ptr), lead_dim_kv, batch_stride_kv,
         static_cast<const half*>(q_lin_results_ptr), lead_dim_q, batch_stride_q, h_beta,
         static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
         static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
         attn_batches, flags);
   } else {
     // Input Linear Q Fwd, Input Linear KV Fwd and MatMul1 through the same wrapped calls,
     // with alpha/beta/scale and a rocblas_datatype_f32_r compute type.
   }
   // Padded Softmax
   bool softmax_success = false;
@@ -303,108 +225,55 @@ std::vector<torch::Tensor> fwd_cuda(
                                        (1.0f - dropout_prob));
   }
   if (use_fp16) {
     // Matmul2
     gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, h_alpha,
         static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
         (is_training) ? static_cast<const half*>(dropout_results.data_ptr())
                       : static_cast<const half*>(softmax_results.data_ptr()),
         //static_cast<const half*>(dropout_results.data_ptr()),
         k_seq_len, k_seq_len * q_seq_len, h_beta,
         static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
         static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
         attn_batches, flags);
     // Output Linear
     TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
         (rocblas_handle)handle,
         hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
         embed_dim, batches_q, embed_dim,
         static_cast<const void*>(&h_alpha),
         static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
         static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
         static_cast<const void*>(&h_beta),
         static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
         static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
         /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
         rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
   } else {
     // Same Matmul2 and Output Linear calls, issued with alpha/beta and a
     // rocblas_datatype_f32_r compute type.
   }
   // End-of-block Dropout-Add
   if (is_training) {
@@ -533,32 +402,57 @@ std::vector<torch::Tensor> bwd_cuda(
                                        total_tokens_q,
                                        (1.0 / (1.0 - dropout_prob)));
   if (use_fp16) {
     // Output Linear Dgrad
     TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
         (rocblas_handle)handle,
         hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
         embed_dim, batches_q, embed_dim,
         static_cast<const void*>(&h_alpha),
         static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
         static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
         static_cast<const void*>(&h_beta),
         static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
         static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
         /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
         rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Output Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T,
-        embed_dim, embed_dim, batches_q,
-        static_cast<const void*>(&alpha),
-        static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
-        static_cast<const void*>(&beta),
-        static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        rocblas_datatype_f32_r /*compute_type*/,
-        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
     // Output Linear Wgrad
     TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
         (rocblas_handle)handle,
@@ -749,310 +643,156 @@ std::vector<torch::Tensor> bwd_cuda(
                                k_seq_len, attn_batches * q_seq_len);
   assert(softmax_success);
   if (use_fp16) {
     // Matmul1 Dgrad1
     gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, h_scale,
         k_lin_results_ptr, lead_dim_kv, batch_stride_kv,
         static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
         q_lin_grads_ptr, lead_dim_q, batch_stride_q,
         q_lin_grads_ptr, lead_dim_q, batch_stride_q, attn_batches, flags);
     // Matmul1 Dgrad2
     gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, h_scale,
         q_lin_results_ptr, lead_dim_q, batch_stride_q,
         static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
         k_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
         k_lin_grads_ptr, lead_dim_kv, batch_stride_kv, attn_batches, flags);
     // Input Linear Q Dgrad
     TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
         (rocblas_handle)handle,
         hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
         embed_dim, batches_q, output_lin_q_dim,
         static_cast<const void*>(&h_alpha),
         static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
         static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/, output_lin_q_dim,
         static_cast<const void*>(&h_beta),
         //static_cast<void*>(input_q_grads.data_ptr()),
         static_cast<void*>(input_lin_q_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
         static_cast<void*>(input_lin_q_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
         /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
         rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
     // Input Linear Q Wgrad   (CUBLAS_OP_N, CUBLAS_OP_T; inputs_q, q_lin_grads_ptr -> input_weight_q_grads)
     // Input Linear KV Dgrad  (CUBLAS_OP_N, CUBLAS_OP_N; input_weights_kv, k_lin_grads_ptr -> input_kv_grads)
     // Input Linear KV Wgrad  (CUBLAS_OP_N, CUBLAS_OP_T; inputs_kv, k_lin_grads_ptr -> input_weight_kv_grads)
     // Each issued through the same wrapped rocblas_gemm_ex call with h_alpha/h_beta and the
     // /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r compute type.
   } else {
     // The same six operations issued with scale/alpha/beta and a rocblas_datatype_f32_r compute type.
   }
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half, float>(
@@ -1080,4 +820,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace encdec_norm_add
 } // end namespace multihead_attn
 \ No newline at end of file
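The *_norm_add variant ends its forward pass with an "End-of-block Dropout-Add" step and its backward pass with a fused layer-norm gradient plus residual add. The kernel below is a generic, hedged illustration of the forward-side fusion (dropout applied to the block output, scaled by 1/(1-p), added to the residual, with the keep mask recorded for backward); it is not apex's kernel, which has its own RNG and launch machinery, and every name in it is hypothetical.

    #include <hip/hip_runtime.h>
    #include <hip/hip_fp16.h>
    #include <hiprand/hiprand_kernel.h>

    // Generic sketch of a fused dropout + residual-add epilogue (not apex's kernel):
    // out[i] = residual[i] + (keep ? in[i] / (1 - p) : 0), recording the mask for backward.
    __global__ void dropout_add_sketch(const __half* in, const __half* residual, __half* out,
                                       uint8_t* mask, int n, float p, unsigned long long seed) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= n) return;
      hiprandStatePhilox4_32_10_t state;
      hiprand_init(seed, i, 0, &state);
      bool keep = hiprand_uniform(&state) >= p;
      float scaled = keep ? __half2float(in[i]) / (1.0f - p) : 0.0f;
      mask[i] = keep;
      out[i] = __float2half(__half2float(residual[i]) + scaled);
    }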
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
View file @
b5d7745d
...
@@ -90,104 +90,53 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
...
@@ -90,104 +90,53 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
+if (use_fp16) {
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
      output_lin_dim, batches, embed_dim,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      static_cast<const void*>(&beta_one),
+      static_cast<const void*>(&h_beta_one),
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim,
-      scale,
+      h_scale,
      static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride,
-      beta_zero,
+      h_beta_zero,
      static_cast<half*>(bmm1_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half*>(bmm1_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);
+} else {
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+      output_lin_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta_one),
+      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
+      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
+  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
+      static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,
+      static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride,
+      beta_zero,
+      static_cast<half*>(bmm1_results_ptr), k_seq_len, k_seq_len * q_seq_len,
+      static_cast<half*>(bmm1_results_ptr), k_seq_len, k_seq_len * q_seq_len,
+      attn_batches, flags);
+}
// Padded Softmax
bool softmax_success = false;
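For reference, the BMM1 call above computes a scaled K^T*Q product per attention batch with float accumulation. The loop below is a CPU-only sketch of that contract under the layout implied by lead_dim and batch_stride (head_dim fastest, column-major per position); it uses float inputs for portability, while the kernel operates on half.

// Reference sketch of the BMM1 semantics (not the kernel itself).
#include <cstddef>

void bmm1_reference(const float* q, const float* k, float* scores,
                    std::size_t head_dim, std::size_t q_seq_len,
                    std::size_t k_seq_len, std::size_t lead_dim,
                    std::size_t batch_stride, std::size_t attn_batches,
                    float scale) {
  for (std::size_t b = 0; b < attn_batches; ++b) {
    const float* qb = q + b * batch_stride;
    const float* kb = k + b * batch_stride;
    float* sb = scores + b * k_seq_len * q_seq_len;
    for (std::size_t qi = 0; qi < q_seq_len; ++qi)
      for (std::size_t ki = 0; ki < k_seq_len; ++ki) {
        float acc = 0.f;  // fp32 accumulator, mirroring gemm_switch_fp32accum
        for (std::size_t d = 0; d < head_dim; ++d)
          acc += kb[ki * lead_dim + d] * qb[qi * lead_dim + d];
        sb[qi * k_seq_len + ki] = scale * acc;  // beta_zero: overwrite the output
      }
  }
}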
...
@@ -213,108 +162,55 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
+if (use_fp16) {
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len,
-      alpha,
+      h_alpha,
      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-      beta_zero,
+      h_beta_zero,
      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      attn_batches, flags);
  outputs.copy_(output_biases);
  // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
      embed_dim, batches, embed_dim,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      static_cast<const void*>(&beta_one),
+      static_cast<const void*>(&h_beta_one),
      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
+} else {
+  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
+      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
+      static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      beta_zero,
+      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
+      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
+      attn_batches, flags);
+  outputs.copy_(output_biases);
+  // Output Linear
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+      embed_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta_one),
+      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, bmm1_results, dropout_results,
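The bias handling above relies on outputs.copy_(output_biases) followed by a GEMM with beta = 1 (beta_one / h_beta_one), so the product accumulates on top of the broadcast bias. A plain C++ sketch of that contract, assuming the column-major, embed_dim-leading layout used in the call:

// Reference sketch: y = W^T x + b, with the bias pre-loaded into y (beta = 1 path).
#include <cstddef>

void output_linear_reference(const float* w,    // [embed_dim x embed_dim], column-major
                             const float* x,    // [embed_dim x batches], column-major
                             const float* bias, // [embed_dim]
                             float* y,          // [embed_dim x batches], column-major
                             std::size_t embed_dim, std::size_t batches) {
  for (std::size_t col = 0; col < batches; ++col)
    for (std::size_t row = 0; row < embed_dim; ++row) {
      float acc = bias[row];  // outputs.copy_(output_biases) + beta_one accumulation
      for (std::size_t kk = 0; kk < embed_dim; ++kk)
        acc += w[row * embed_dim + kk] * x[col * embed_dim + kk];  // op(A) = A^T
      y[col * embed_dim + row] = acc;
    }
}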
...
@@ -392,442 +288,222 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
+if (use_fp16) {
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
      embed_dim, batches, embed_dim,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      static_cast<const void*>(&beta),
+      static_cast<const void*>(&h_beta),
      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
      embed_dim, embed_dim, batches,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      static_cast<const void*>(&beta),
+      static_cast<const void*>(&h_beta),
      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
  // MatMul2 Dgrad1
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim,
-      alpha,
+      h_alpha,
      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
-      beta,
+      h_beta,
      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);
  // Matmul2 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len,
-      alpha,
+      h_alpha,
      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-      beta,
+      h_beta,
      v_lin_grads_ptr, lead_dim, batch_stride,
      v_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);
  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
      static_cast<half*>(matmul2_grads.data_ptr()),
      static_cast<half* const>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const*>(bmm1_results.data_ptr()),
      reinterpret_cast<half const*>(pad_mask.data_ptr()),
      static_cast<uint8_t const*>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob),
      k_seq_len, k_seq_len,
      attn_batches * q_seq_len / sequences, attn_batches * q_seq_len, stream);
  // Matmul1 Dgrad1
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len,
-      scale,
+      h_scale,
      k_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-      beta,
+      h_beta,
      q_lin_grads_ptr, lead_dim, batch_stride,
      q_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);
  // Matmul1 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len,
-      scale,
+      h_scale,
      q_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-      beta,
+      h_beta,
      k_lin_grads_ptr, lead_dim, batch_stride,
      k_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);
  // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
      embed_dim, batches, output_lin_dim,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(input_lin_output_grads.data_ptr()), rocblas_datatype_f16_r, output_lin_dim,
-      static_cast<const void*>(&beta),
+      static_cast<const void*>(&h_beta),
      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
      embed_dim, output_lin_dim, batches,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
-      static_cast<const void*>(&beta),
+      static_cast<const void*>(&h_beta),
      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
+} else {
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+      embed_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta),
+      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  // Output Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+      embed_dim, embed_dim, batches,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta),
+      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
+  // MatMul2 Dgrad1
+  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
+      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
+      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
+      beta,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      attn_batches, flags);
+  // Matmul2 Dgrad2
+  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
+      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
+      static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      beta,
+      v_lin_grads_ptr, lead_dim, batch_stride,
+      v_lin_grads_ptr, lead_dim, batch_stride,
+      attn_batches, flags);
+  // Apply Dropout Mask and Scale by Dropout Probability
+  // Softmax Grad
+  dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
+      static_cast<half*>(matmul2_grads.data_ptr()),
+      static_cast<half* const>(matmul2_grads.data_ptr()),
+      reinterpret_cast<half const*>(bmm1_results.data_ptr()),
+      reinterpret_cast<half const*>(pad_mask.data_ptr()),
+      static_cast<uint8_t const*>(dropout_mask.data_ptr()),
+      1.0 / (1.0 - dropout_prob),
+      k_seq_len, k_seq_len,
+      attn_batches * q_seq_len / sequences, attn_batches * q_seq_len, stream);
+  // Matmul1 Dgrad1
+  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
+      k_lin_results_ptr, lead_dim, batch_stride,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      beta,
+      q_lin_grads_ptr, lead_dim, batch_stride,
+      q_lin_grads_ptr, lead_dim, batch_stride,
+      attn_batches, flags);
+  // Matmul1 Dgrad2
+  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
+      q_lin_results_ptr, lead_dim, batch_stride,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      beta,
+      k_lin_grads_ptr, lead_dim, batch_stride,
+      k_lin_grads_ptr, lead_dim, batch_stride,
+      attn_batches, flags);
+  // Input Linear Dgrad
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+      embed_dim, batches, output_lin_dim,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(input_lin_output_grads.data_ptr()), rocblas_datatype_f16_r, output_lin_dim,
+      static_cast<const void*>(&beta),
+      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  // Input Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+      embed_dim, output_lin_dim, batches,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
+      static_cast<const void*>(&beta),
+      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
+  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  return {input_grads, input_weight_grads, output_weight_grads,
+          input_bias_grads, output_bias_grads};
+}
}
} // end namespace rocblas_gemmex
...
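The masked-scale softmax backward used above combines the dropout rescaling with the usual softmax Jacobian. A single-row reference sketch follows (the kernel recomputes the softmax output y from bmm1_results and pad_mask; here y is simply taken as given):

// Reference sketch: dx_i = y_i * (g_i - sum_j g_j * y_j), with
// g_i = dy_i * mask_i / (1 - dropout_prob), matching the 1.0/(1.0 - dropout_prob)
// argument passed to the dispatch call above.
#include <cstddef>
#include <cstdint>

void softmax_backward_row(const float* grad_out, const float* y,
                          const std::uint8_t* dropout_mask, float inv_keep,
                          float* grad_in, std::size_t n) {
  float dot = 0.f;
  for (std::size_t i = 0; i < n; ++i)
    dot += grad_out[i] * dropout_mask[i] * inv_keep * y[i];
  for (std::size_t i = 0; i < n; ++i) {
    float g = grad_out[i] * dropout_mask[i] * inv_keep;
    grad_in[i] = y[i] * (g - dot);
  }
}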
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu View file @ b5d7745d
...
@@ -88,104 +88,53 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
+if (use_fp16) {
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
      output_lin_dim, batches, embed_dim,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      static_cast<const void*>(&beta_one),
+      static_cast<const void*>(&h_beta_one),
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim,
-      scale,
+      h_scale,
      static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride,
-      beta_zero,
+      h_beta_zero,
      static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);
+} else {
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+      output_lin_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta_one),
+      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
+      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
+  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
+      static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,
+      static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride,
+      beta_zero,
+      static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
+      static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
+      attn_batches, flags);
+}
// Padded Softmax
bool softmax_success = false;
...
@@ -219,108 +168,55 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
}
// Matmul2
+if (use_fp16) {
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len,
-      alpha,
+      h_alpha,
      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
      (is_training) ? static_cast<const half*>(dropout_results.data_ptr())
                    : static_cast<const half*>(softmax_results.data_ptr()),
      k_seq_len, k_seq_len * q_seq_len,
-      beta_zero,
+      h_beta_zero,
      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
      attn_batches, flags);
  outputs.copy_(output_biases);
  // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_T, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
      embed_dim, batches, embed_dim,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      static_cast<const void*>(&beta_one),
+      static_cast<const void*>(&h_beta_one),
      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
+} else {
+  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
+      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
+      (is_training) ? static_cast<const half*>(dropout_results.data_ptr())
+                    : static_cast<const half*>(softmax_results.data_ptr()),
+      k_seq_len, k_seq_len * q_seq_len,
+      beta_zero,
+      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
+      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
+      attn_batches, flags);
+  outputs.copy_(output_biases);
+  // Output Linear
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
+      embed_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta_one),
+      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
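The (is_training) ? dropout_results : softmax_results selection above reflects inverted dropout: only during training are the attention weights masked and rescaled by 1/(1 - p), while at inference the softmax output is used unchanged. A minimal sketch of that convention:

// Sketch of the inverted-dropout selection implied by the ternary above.
#include <cstddef>
#include <cstdint>

void apply_inverted_dropout(const float* softmax_out, const std::uint8_t* mask,
                            float dropout_prob, bool is_training,
                            float* out, std::size_t n) {
  const float inv_keep = 1.0f / (1.0f - dropout_prob);
  for (std::size_t i = 0; i < n; ++i)
    out[i] = is_training ? softmax_out[i] * mask[i] * inv_keep : softmax_out[i];
}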
...
@@ -398,432 +294,218 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
+if (use_fp16) {
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
      embed_dim, batches, embed_dim,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      static_cast<const void*>(&beta),
+      static_cast<const void*>(&h_beta),
      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
      embed_dim, embed_dim, batches,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      static_cast<const void*>(&beta),
+      static_cast<const void*>(&h_beta),
      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
  // MatMul2 Dgrad1
  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim,
-      alpha,
+      h_alpha,
      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
-      beta,
+      h_beta,
      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
      attn_batches, flags);
  // Matmul2 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len,
-      alpha,
+      h_alpha,
      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
      static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-      beta,
+      h_beta,
      v_lin_grads_ptr, lead_dim, batch_stride,
      v_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);
  // Apply Dropout Mask and Scale by Dropout Probability
  // Softmax Grad
  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
      static_cast<half*>(matmul2_grads.data_ptr()),
      static_cast<half*>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const*>(softmax_results.data_ptr()),
      static_cast<uint8_t const*>(dropout_mask.data_ptr()),
      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
      attn_batches * q_seq_len, stream);
  // Matmul1 Dgrad1
  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len,
-      scale,
+      h_scale,
      k_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-      beta,
+      h_beta,
      q_lin_grads_ptr, lead_dim, batch_stride,
      q_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);
  // Matmul1 Dgrad2
  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len,
-      scale,
+      h_scale,
      q_lin_results_ptr, lead_dim, batch_stride,
      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-      beta,
+      h_beta,
      k_lin_grads_ptr, lead_dim, batch_stride,
      k_lin_grads_ptr, lead_dim, batch_stride,
      attn_batches, flags);
  // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_N, CUBLAS_OP_N,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
      embed_dim, batches, output_lin_dim,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(input_lin_output_grads.data_ptr()), rocblas_datatype_f16_r, output_lin_dim,
-      static_cast<const void*>(&beta),
+      static_cast<const void*>(&h_beta),
      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-      CUBLAS_OP_N, CUBLAS_OP_T,
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
      embed_dim, output_lin_dim, batches,
-      static_cast<const void*>(&alpha),
+      static_cast<const void*>(&h_alpha),
      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
-      static_cast<const void*>(&beta),
+      static_cast<const void*>(&h_beta),
      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
-      flags));
+      flags)));
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_grads, input_weight_grads, output_weight_grads,
          input_bias_grads, output_bias_grads};
+} else {
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+      embed_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta),
+      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  // Output Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+      embed_dim, embed_dim, batches,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta),
+      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
+  // MatMul2 Dgrad1
+  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
+      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
+      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
+      beta,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      attn_batches, flags);
+  // Matmul2 Dgrad2
+  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
+      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
+      static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      beta,
+      v_lin_grads_ptr, lead_dim, batch_stride,
+      v_lin_grads_ptr, lead_dim, batch_stride,
+      attn_batches, flags);
+  // Apply Dropout Mask and Scale by Dropout Probability
+  // Softmax Grad
+  dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
+      static_cast<half*>(matmul2_grads.data_ptr()),
+      static_cast<half*>(matmul2_grads.data_ptr()),
+      reinterpret_cast<half const*>(softmax_results.data_ptr()),
+      static_cast<uint8_t const*>(dropout_mask.data_ptr()),
+      1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
+      attn_batches * q_seq_len, stream);
+  // Matmul1 Dgrad1
+  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
+      k_lin_results_ptr, lead_dim, batch_stride,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      beta,
+      q_lin_grads_ptr, lead_dim, batch_stride,
+      q_lin_grads_ptr, lead_dim, batch_stride,
+      attn_batches, flags);
+  // Matmul1 Dgrad2
+  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
+      q_lin_results_ptr, lead_dim, batch_stride,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      beta,
+      k_lin_grads_ptr, lead_dim, batch_stride,
+      k_lin_grads_ptr, lead_dim, batch_stride,
+      attn_batches, flags);
+  // Input Linear Dgrad
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N),
+      embed_dim, batches, output_lin_dim,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(input_lin_output_grads.data_ptr()), rocblas_datatype_f16_r, output_lin_dim,
+      static_cast<const void*>(&beta),
+      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  // Input Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
+      hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T),
+      embed_dim, output_lin_dim, batches,
+      static_cast<const void*>(&alpha),
+      static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
+      static_cast<const void*>(&beta),
+      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r,
+      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
+  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
+  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  return {input_grads, input_weight_grads, output_weight_grads,
+          input_bias_grads, output_bias_grads};
+}
}
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
\ No newline at end of file
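Both bias-carrying variants above compute the bias gradients with view({-1, dim}).sum(0, false), i.e. a row-sum of the incoming gradient over every token position. Reference sketch of that reduction:

// Reference sketch: gradient of a broadcast bias is the sum of the incoming
// gradient over all rows (tokens).
#include <cstddef>

void bias_grad_reference(const float* grad, float* bias_grad,
                         std::size_t rows, std::size_t embed_dim) {
  for (std::size_t j = 0; j < embed_dim; ++j) bias_grad[j] = 0.f;
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t j = 0; j < embed_dim; ++j)
      bias_grad[j] += grad[r * embed_dim + j];
}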
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
View file @
b5d7745d
...
@@ -85,9 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
...
@@ -85,9 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
// Input Linear Fwd
TORCH_CUDABLAS_CHECK
(
roc
BLASStatusToHIPStatus
(
rocblas_gemm_ex
((
rocblas_handle
)
handle
,
TORCH_CUDABLAS_CHECK
(
roc
blas_gemm_ex
(
handle
,
hipOperationToRocOperation
(
CUBLAS_OP_T
)
,
CUBLAS_OP_T
,
hipOperationToRocOperation
(
CUBLAS_OP_N
)
,
CUBLAS_OP_N
,
output_lin_dim
,
output_lin_dim
,
batches
,
batches
,
embed_dim
,
embed_dim
,
...
@@ -108,7 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
...
@@ -108,7 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */
rocblas_datatype_f16_r
,
/* rocblas_datatype_f32_r */
rocblas_datatype_f16_r
,
rocblas_gemm_algo_standard
/*algo*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
0
/*solution_index*/
,
flags
))
)
;
flags
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum
(
a_layout_t
,
gemm_switch_fp32accum
(
a_layout_t
,
...
@@ -188,9 +188,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
...
@@ -188,9 +188,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags
);
flags
);
// Output Linear
// Output Linear
TORCH_CUDABLAS_CHECK
(
roc
BLASStatusToHIPStatus
(
rocblas_gemm_ex
((
rocblas_handle
)
handle
,
TORCH_CUDABLAS_CHECK
(
roc
blas_gemm_ex
(
handle
,
hipOperationToRocOperation
(
CUBLAS_OP_T
)
,
CUBLAS_OP_T
,
hipOperationToRocOperation
(
CUBLAS_OP_N
)
,
CUBLAS_OP_N
,
embed_dim
,
embed_dim
,
batches
,
batches
,
embed_dim
,
embed_dim
,
...
@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
...
@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */
rocblas_datatype_f16_r
,
/* rocblas_datatype_f32_r */
rocblas_datatype_f16_r
,
rocblas_gemm_algo_standard
/*algo*/
,
rocblas_gemm_algo_standard
/*algo*/
,
0
/*solution_index*/
,
0
/*solution_index*/
,
flags
))
)
;
flags
));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return
{
input_lin_results
,
softmax_results
,
dropout_results
,
return
{
input_lin_results
,
softmax_results
,
dropout_results
,
...
@@ -289,202 +289,102 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -289,202 +289,102 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
#endif
 // Output Linear Dgrad
-  if (use_fp16) {
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, embed_dim,
-        static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(&h_beta), static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Output Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, embed_dim, batches,
-        static_cast<const void*>(&h_alpha), static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(&h_beta), static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // MatMul2 Dgrad1
-    gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, h_alpha,
-        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
-        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, h_beta,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches, flags);
-    // Matmul2 Dgrad2
-    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, h_alpha,
-        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
-        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
-        v_lin_grads_ptr, lead_dim, batch_stride, v_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-  } else {
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, embed_dim,
-        static_cast<const void*>(&alpha), static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Output Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, embed_dim, batches,
-        static_cast<const void*>(&alpha), static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // MatMul2 Dgrad1
-    gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
-        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
-        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches, flags);
-    // Matmul2 Dgrad2
-    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
-        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
-        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
-        v_lin_grads_ptr, lead_dim, batch_stride, v_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-  }
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha), static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
+  // Output Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
+      static_cast<const void*>(&alpha), static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
+  // MatMul2 Dgrad1
+  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
+      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
+      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches, flags);
+  // Matmul2 Dgrad2
+  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
+      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
+      static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
+      v_lin_grads_ptr, lead_dim, batch_stride, v_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
 // Apply Dropout Mask and Scale by Dropout Probability
 apex_masked_scale_cuda<at::Half, float, uint32_t>(
...
@@ -504,202 +404,102 @@ std::vector<torch::Tensor> bwd_cuda(
   assert(softmax_success);
 // Matmul1 Dgrad1
-  if (use_fp16) {
-    gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, h_scale,
-        k_lin_results_ptr, lead_dim, batch_stride,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
-        q_lin_grads_ptr, lead_dim, batch_stride, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-    // Matmul1 Dgrad2
-    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, h_scale,
-        q_lin_results_ptr, lead_dim, batch_stride,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
-        k_lin_grads_ptr, lead_dim, batch_stride, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-    // Input Linear Dgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, output_lin_dim,
-        static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
-        static_cast<const void*>(&h_beta), static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Input Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, output_lin_dim, batches,
-        static_cast<const void*>(&h_alpha), static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
-        static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-  } else {
-    gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
-        k_lin_results_ptr, lead_dim, batch_stride,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
-        q_lin_grads_ptr, lead_dim, batch_stride, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-    // Matmul1 Dgrad2
-    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
-        q_lin_results_ptr, lead_dim, batch_stride,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
-        k_lin_grads_ptr, lead_dim, batch_stride, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-    // Input Linear Dgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, output_lin_dim,
-        static_cast<const void*>(&alpha), static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
-        static_cast<const void*>(&beta), static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Input Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, output_lin_dim, batches,
-        static_cast<const void*>(&alpha), static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
-        static_cast<const void*>(&beta), static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-        rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-  }
+  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
+      k_lin_results_ptr, lead_dim, batch_stride,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
+      q_lin_grads_ptr, lead_dim, batch_stride, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
+  // Matmul1 Dgrad2
+  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
+      q_lin_results_ptr, lead_dim, batch_stride,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
+      k_lin_grads_ptr, lead_dim, batch_stride, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
+  // Input Linear Dgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
+      static_cast<const void*>(&alpha), static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
+      static_cast<const void*>(&beta), static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
+  // Input Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
+      static_cast<const void*>(&alpha), static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
+      static_cast<const void*>(&beta), static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+      rocblas_datatype_f32_r, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 return {
...
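The first two checked GEMMs in the hunk above are the standard linear-layer backward pass: a no-transpose/no-transpose product for the data gradient and a no-transpose/transpose product for the weight gradient over (embed_dim, batches) operands. Below is a small plain-C++ reference of just those two shapes, with made-up sizes and column-major storage; it is an illustration of the transpose choices, not the rocBLAS path used in this file.

// Plain-C++ reference for the dgrad/wgrad shapes set up by the CUBLAS_OP_N/OP_N
// and CUBLAS_OP_N/OP_T calls above. Column-major like BLAS; sizes are made up.
#include <cstddef>
#include <vector>

// C(m x n) += A(m x k) * B(k x n), no transposes (the "N,N" dgrad shape).
void gemm_nn(std::size_t m, std::size_t n, std::size_t k,
             const std::vector<float>& A, const std::vector<float>& B, std::vector<float>& C) {
  for (std::size_t j = 0; j < n; ++j)
    for (std::size_t p = 0; p < k; ++p)
      for (std::size_t i = 0; i < m; ++i)
        C[i + j * m] += A[i + p * m] * B[p + j * k];
}

// C(m x n) += A(m x k) * B(n x k)^T (the "N,T" wgrad shape: activations times grad^T).
void gemm_nt(std::size_t m, std::size_t n, std::size_t k,
             const std::vector<float>& A, const std::vector<float>& B, std::vector<float>& C) {
  for (std::size_t j = 0; j < n; ++j)
    for (std::size_t p = 0; p < k; ++p)
      for (std::size_t i = 0; i < m; ++i)
        C[i + j * m] += A[i + p * m] * B[j + p * n];
}

int main() {
  const std::size_t embed_dim = 4, batches = 3;
  std::vector<float> W(embed_dim * embed_dim, 0.5f);   // plays the role of output_weights
  std::vector<float> dY(embed_dim * batches, 1.0f);    // output_grads
  std::vector<float> X(embed_dim * batches, 2.0f);     // matmul2_results
  std::vector<float> dX(embed_dim * batches, 0.0f);    // output_lin_grads
  std::vector<float> dW(embed_dim * embed_dim, 0.0f);  // output_weight_grads
  gemm_nn(embed_dim, batches, embed_dim, W, dY, dX);   // dgrad
  gemm_nt(embed_dim, embed_dim, batches, X, dY, dW);   // wgrad
  return 0;
}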
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
View file @
b5d7745d
...
@@ -106,106 +106,54 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
 // Input Linear Fwd
-  if (use_fp16) {
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N), output_lin_dim, batches, embed_dim,
-        static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        //static_cast<const void*>(inputs.data_ptr()),
-        static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
-        static_cast<const void*>(&h_beta), q_lin_results_ptr, rocblas_datatype_f16_r /*c_type*/, output_lin_dim,
-        q_lin_results_ptr, rocblas_datatype_f16_r /*d_type*/, output_lin_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-    gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, h_scale,
-        static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,
-        static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride, h_beta,
-        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
-        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches, flags);
-  } else {
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N), output_lin_dim, batches, embed_dim,
-        static_cast<const void*>(&alpha), static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        //static_cast<const void*>(inputs.data_ptr()),
-        static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
-        static_cast<const void*>(&beta), q_lin_results_ptr, rocblas_datatype_f16_r /*c_type*/, output_lin_dim,
-        q_lin_results_ptr, rocblas_datatype_f16_r /*d_type*/, output_lin_dim,
-        rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
-    gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
-        static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,
-        static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride, beta,
-        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
-        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches, flags);
-  }
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N, output_lin_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha), static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
+      //static_cast<const void*>(inputs.data_ptr()),
+      static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
+      static_cast<const void*>(&beta), q_lin_results_ptr, rocblas_datatype_f16_r /*c_type*/, output_lin_dim,
+      q_lin_results_ptr, rocblas_datatype_f16_r /*d_type*/, output_lin_dim,
+      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
+  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
+  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale,
+      static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,
+      static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride, beta,
+      static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
+      static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len, attn_batches, flags);
 // Padded Softmax
 bool softmax_success = false;
...
@@ -239,106 +187,54 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 }
 // Matmul2
-  if (use_fp16) {
-    gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, h_alpha,
-        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
-        (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
-        //static_cast<const half*>(dropout_results.data_ptr()),
-        k_seq_len, k_seq_len * q_seq_len, h_beta,
-        static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
-        static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim, attn_batches, flags);
-    // Output Linear
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, embed_dim,
-        static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
-        static_cast<const void*>(&h_beta), static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-  } else {
-    gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
-        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
-        (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
-        //static_cast<const half*>(dropout_results.data_ptr()),
-        k_seq_len, k_seq_len * q_seq_len, beta,
-        static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
-        static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim, attn_batches, flags);
-    // Output Linear
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, embed_dim,
-        static_cast<const void*>(&alpha), static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
-        static_cast<const void*>(&beta), static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-  }
+  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha,
+      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
+      (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
+      //static_cast<const half*>(dropout_results.data_ptr()),
+      k_seq_len, k_seq_len * q_seq_len, beta,
+      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
+      static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim, attn_batches, flags);
+  // Output Linear
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_T, CUBLAS_OP_N, embed_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha), static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
+      static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
+      static_cast<const void*>(&beta), static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
+      static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
+      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
 // End-of-block Dropout-Add
...
@@ -451,202 +347,102 @@ std::vector<torch::Tensor> bwd_cuda(
                                  (1.0 / (1.0 - dropout_prob)));
 // Output Linear Dgrad
-  if (use_fp16) {
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, embed_dim,
-        static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
-        static_cast<const void*>(&h_beta), static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Output Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, embed_dim, batches,
-        static_cast<const void*>(&h_alpha), static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
-        static_cast<const void*>(&h_beta), static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // MatMul2 Dgrad1
-    gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, h_alpha,
-        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
-        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, h_beta,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches, flags);
-    // Matmul2 Dgrad2
-    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, h_alpha,
-        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
-        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
-        v_lin_grads_ptr, lead_dim, batch_stride, v_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-  } else {
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, embed_dim,
-        static_cast<const void*>(&alpha), static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
-        static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Output Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, embed_dim, batches,
-        static_cast<const void*>(&alpha), static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
-        static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // MatMul2 Dgrad1
-    gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
-        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
-        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches, flags);
-    // Matmul2 Dgrad2
-    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
-        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
-        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
-        v_lin_grads_ptr, lead_dim, batch_stride, v_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-  }
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, embed_dim,
+      static_cast<const void*>(&alpha), static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
+      static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
+      static_cast<const void*>(&beta), static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
+      static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
+      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
+  // Output Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, embed_dim, batches,
+      static_cast<const void*>(&alpha), static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
+      static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/, embed_dim,
+      static_cast<const void*>(&beta), static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
+      static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
+      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
+  // MatMul2 Dgrad1
+  gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
+      static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
+      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim, beta,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, attn_batches, flags);
+  // Matmul2 Dgrad2
+  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
+      static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
+      static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
+      v_lin_grads_ptr, lead_dim, batch_stride, v_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
 // Apply Dropout Mask and Scale by Dropout Probability
 apex_masked_scale_cuda<at::Half, float, uint32_t>(
...
@@ -666,206 +462,104 @@ std::vector<torch::Tensor> bwd_cuda(
   assert(softmax_success);
 // Matmul1 Dgrad1
-  if (use_fp16) {
-    gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, h_scale,
-        k_lin_results_ptr, lead_dim, batch_stride,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
-        q_lin_grads_ptr, lead_dim, batch_stride, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-    // Matmul1 Dgrad2
-    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, h_scale,
-        q_lin_results_ptr, lead_dim, batch_stride,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, h_beta,
-        k_lin_grads_ptr, lead_dim, batch_stride, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-    // Input Linear Dgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, output_lin_dim,
-        static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
-        static_cast<const void*>(&h_beta),
-        //static_cast<void*>(input_grads.data_ptr()),
-        static_cast<void*>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Input Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, output_lin_dim, batches,
-        static_cast<const void*>(&h_alpha),
-        //static_cast<const void*>(inputs.data_ptr()),
-        static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
-        static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-  } else {
-    gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
-        k_lin_results_ptr, lead_dim, batch_stride,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
-        q_lin_grads_ptr, lead_dim, batch_stride, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-    // Matmul1 Dgrad2
-    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
-        q_lin_results_ptr, lead_dim, batch_stride,
-        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
-        k_lin_grads_ptr, lead_dim, batch_stride, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
-    // Input Linear Dgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches, output_lin_dim,
-        static_cast<const void*>(&alpha), static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
-        static_cast<const void*>(&beta),
-        //static_cast<void*>(input_grads.data_ptr()),
-        static_cast<void*>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-    // Input Linear Wgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(CUBLAS_OP_N), hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, output_lin_dim, batches,
-        static_cast<const void*>(&alpha),
-        //static_cast<const void*>(inputs.data_ptr()),
-        static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
-        static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
-        static_cast<const void*>(&beta), static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
-        static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
-        rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
-  }
+  gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale,
+      k_lin_results_ptr, lead_dim, batch_stride,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
+      q_lin_grads_ptr, lead_dim, batch_stride, q_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
+  // Matmul1 Dgrad2
+  gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale,
+      q_lin_results_ptr, lead_dim, batch_stride,
+      static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta,
+      k_lin_grads_ptr, lead_dim, batch_stride, k_lin_grads_ptr, lead_dim, batch_stride, attn_batches, flags);
+  // Input Linear Dgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, embed_dim, batches, output_lin_dim,
+      static_cast<const void*>(&alpha), static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
+      static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
+      static_cast<const void*>(&beta),
+      //static_cast<void*>(input_grads.data_ptr()),
+      static_cast<void*>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
+      static_cast<void*>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
+      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
+  // Input Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_T, embed_dim, output_lin_dim, batches,
+      static_cast<const void*>(&alpha),
+      //static_cast<const void*>(inputs.data_ptr()),
+      static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/, embed_dim,
+      static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
+      static_cast<const void*>(&beta), static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/, embed_dim,
+      static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/, embed_dim,
+      rocblas_datatype_f32_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
 // Fused Layer Norm Bwd with Residual Add
 HostLayerNormGradient<half, float>(
...
@@ -889,4 +583,4 @@ std::vector<torch::Tensor> bwd_cuda(
 } // end namespace rocblas_gemmex
 } // end namespace self_norm_add
 } // end namespace multihead_attn
\ No newline at end of file
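Both attention sources in this commit drop the same use_fp16 specialization, in which the float scaling factors were converted to half (as the __float2half calls still visible in csrc/mlp_cuda.cu further down do) and the GEMM compute type switched from rocblas_datatype_f32_r to rocblas_datatype_f16_r. The sketch below mirrors that scalar/compute-type selection, assuming a ROCm toolchain with <hip/hip_fp16.h> and <rocblas/rocblas.h> on the include path; it is not code that remains in this tree.

// Sketch only: the selection the removed use_fp16 branches performed.
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>

struct GemmScalars {
  float f_alpha, f_beta;     // passed when accumulating in fp32 (the default)
  __half h_alpha, h_beta;    // passed when the whole GEMM runs in fp16
  rocblas_datatype compute_type;
};

inline GemmScalars make_gemm_scalars(float alpha, float beta, bool use_fp16) {
  GemmScalars s;
  s.f_alpha = alpha;
  s.f_beta = beta;
  s.h_alpha = __float2half(alpha);
  s.h_beta = __float2half(beta);
  // fp16 compute trades accumulation precision for speed; fp32 keeps full-precision accumulation.
  s.compute_type = use_fp16 ? rocblas_datatype_f16_r : rocblas_datatype_f32_r;
  return s;
}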
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh
View file @
b5d7745d
...
@@ -7,8 +7,6 @@
 //#include <cuda_profiler_api.h>
 #include <cuda_runtime.h>
-#include <rocblas/rocblas.h>
 //#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
...
@@ -47,52 +45,6 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
   }
 }
-// needed to work around calling rocblas API instead of hipblas API
-static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) {
-  switch (op) {
-  case HIPBLAS_OP_N: return rocblas_operation_none;
-  case HIPBLAS_OP_T: return rocblas_operation_transpose;
-  case HIPBLAS_OP_C: return rocblas_operation_conjugate_transpose;
-  }
-  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
-static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
-  switch (error) {
-  case rocblas_status_size_unchanged:
-  case rocblas_status_size_increased:
-  case rocblas_status_success:
-  case rocblas_status_continue:
-    return HIPBLAS_STATUS_SUCCESS;
-  case rocblas_status_invalid_handle:
-    return HIPBLAS_STATUS_NOT_INITIALIZED;
-  case rocblas_status_not_implemented:
-  case rocblas_status_excluded_from_build:
-    return HIPBLAS_STATUS_NOT_SUPPORTED;
-  case rocblas_status_invalid_pointer:
-  case rocblas_status_invalid_size:
-  case rocblas_status_invalid_value:
-  case rocblas_status_size_query_mismatch:
-    return HIPBLAS_STATUS_INVALID_VALUE;
-  case rocblas_status_memory_error:
-    return HIPBLAS_STATUS_ALLOC_FAILED;
-  case rocblas_status_internal_error:
-  case rocblas_status_perf_degraded:
-  case rocblas_status_check_numerics_fail:
-    return HIPBLAS_STATUS_INTERNAL_ERROR;
-  case rocblas_status_arch_mismatch:
-    return HIPBLAS_STATUS_ARCH_MISMATCH;
-  }
-  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
 void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
     float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
     float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD,
     long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
...
@@ -105,13 +57,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
   float fAlpha = alpha;
   float fBeta = beta;
   //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
-      hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
-      (void *)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
-      b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
-      (void *)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
-      d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
-      (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)));
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
+      (void *)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
+      b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
+      (void *)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
+      d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
+      (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
 }
 void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
...
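The two helpers removed above exist only to translate between enum families: hipblasOperation_t to rocblas_operation on the way in, and rocblas_status back to a hipblasStatus_t that TORCH_CUDABLAS_CHECK can test. The standalone sketch below reproduces that translation idea with stand-in enums and a stand-in check macro so it can be compiled and run anywhere; the real symbols are the rocblas_status_* / HIPBLAS_STATUS_* values shown in the removed code.

// Standalone sketch of the status-translation pattern; all names are stand-ins.
#include <cstdio>
#include <cstdlib>

enum BackendStatus { kOk, kSizeIncreased, kInvalidSize, kInternalFailure };
enum CheckedStatus { kSuccess, kInvalidValue, kInternalError };

static CheckedStatus translate(BackendStatus s) {
  switch (s) {
    case kOk:
    case kSizeIncreased:   // informational: the call still succeeded
      return kSuccess;
    case kInvalidSize:
      return kInvalidValue;
    default:
      return kInternalError;
  }
}

#define CHECK_STATUS(expr)                                                        \
  do {                                                                            \
    CheckedStatus st_ = translate(expr);                                          \
    if (st_ != kSuccess) {                                                        \
      std::fprintf(stderr, "status %d at %s:%d\n", (int)st_, __FILE__, __LINE__); \
      std::abort();                                                               \
    }                                                                             \
  } while (0)

int main() {
  CHECK_STATUS(kSizeIncreased);  // collapses to success, execution continues
  return 0;
}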
csrc/fused_dense_cuda.cu
View file @
b5d7745d
...
@@ -10,22 +10,10 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
-#include "utils.h"
-#include <rocblas/rocblas.h>
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
 // includes cublaslt
 #include <cublasLt.h>
 #endif
-// until we use hipblas v2
-// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
-// however hipblas v1 is still using its custom type
-#define HIP_R_64F HIPBLAS_R_64F
-#define HIP_R_32F HIPBLAS_R_32F
-#define HIP_R_16F HIPBLAS_R_16F
 // FP64 Wrapper around cublas GEMMEx
 cublasStatus_t gemm_bias(
     cublasHandle_t handle,
...
@@ -42,6 +30,33 @@ cublasStatus_t gemm_bias(
     const float *beta,
     double *C,
     int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
+      A, rocblas_datatype_f64_r, lda, B, rocblas_datatype_f64_r, ldb, beta,
+      C, rocblas_datatype_f64_r, ldc, C, rocblas_datatype_f64_r, ldc,
+      rocblas_datatype_f64_r, rocblas_gemm_algo_standard, 0, 0);
+#else
   return cublasGemmEx(
       handle,
       transa,
...
@@ -62,6 +77,7 @@ cublasStatus_t gemm_bias(
       ldc,
      CUDA_R_64F,
      CUBLAS_GEMM_DEFAULT);
+#endif
 }
 // FP32 Wrapper around cublas GEMMEx
...
@@ -80,6 +96,34 @@ cublasStatus_t gemm_bias(
     const float *beta,
     float *C,
     int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
+      A, rocblas_datatype_f32_r, lda, B, rocblas_datatype_f32_r, ldb, beta,
+      C, rocblas_datatype_f32_r, ldc, C, rocblas_datatype_f32_r, ldc,
+      rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0);
+#else
   return cublasGemmEx(
      handle,
      transa,
...
@@ -100,6 +144,7 @@ cublasStatus_t gemm_bias(
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT);
+#endif
 }
 // FP16 Tensor core wrapper around cublas GEMMEx
...
@@ -118,6 +163,7 @@ cublasStatus_t gemm_bias(
     const float *beta,
     at::Half *C,
     int ldc) {
+<<<<<<< HEAD
   if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
     half h_alpha = __float2half(*alpha);
     half h_beta = __float2half(*beta);
...
@@ -163,6 +209,56 @@ cublasStatus_t gemm_bias(
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
   }
+=======
+#ifdef __HIP_PLATFORM_HCC__
+  return rocblas_gemm_ex(handle, transa, transb, m, n, k, alpha,
+      A, rocblas_datatype_f16_r, lda, B, rocblas_datatype_f16_r, ldb, beta,
+      C, rocblas_datatype_f16_r, ldc, C, rocblas_datatype_f16_r, ldc,
+      rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0);
+#else
+  return cublasGemmEx(handle, transa, transb, m, n, k, alpha,
+      A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb, beta, C, CUDA_R_16F, ldc,
+      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
+>>>>>>> mirror/master
 }
...
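The added gemm_bias bodies all follow one pattern: a compile-time switch on __HIP_PLATFORM_HCC__ that routes the same wrapper either to rocblas_gemm_ex with rocblas_datatype_* descriptors or to cublasGemmEx with CUDA_R_* types. A minimal, compilable sketch of that dispatch shape is below; the two GEMM entry points are replaced by trivial stand-ins, so only the structure is real.

// Compilable sketch of the platform dispatch only; the GEMM calls are stand-ins.
#include <cstdio>

int rocblas_like_gemm(const char *tag) { std::printf("rocBLAS path: %s\n", tag); return 0; }
int cublas_like_gemm(const char *tag)  { std::printf("cuBLAS path: %s\n", tag);  return 0; }

int gemm_bias_sketch(const char *tag) {
#ifdef __HIP_PLATFORM_HCC__
  // ROCm builds: forward to the rocBLAS entry point (rocblas_datatype_* descriptors).
  return rocblas_like_gemm(tag);
#else
  // CUDA builds: forward to cublasGemmEx (CUDA_R_* types).
  return cublas_like_gemm(tag);
#endif
}

int main() { return gemm_bias_sketch("fp64"); }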
csrc/mlp_cuda.cu
View file @
b5d7745d
...
@@ -13,8 +13,6 @@
 #include <cuda_runtime.h>
 #include "utils.h"
-#include <rocblas/rocblas.h>
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
 // includes cublaslt
 #include <cublasLt.h>
...
@@ -62,52 +60,6 @@ __device__ __inline__ float sigmoid(float a) {
   return (retf);
 }
-// needed to work around calling rocblas API instead of hipblas API
-static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) {
-  switch (op) {
-  case HIPBLAS_OP_N: return rocblas_operation_none;
-  case HIPBLAS_OP_T: return rocblas_operation_transpose;
-  case HIPBLAS_OP_C: return rocblas_operation_conjugate_transpose;
-  }
-  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
-static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
-  switch (error) {
-  case rocblas_status_size_unchanged:
-  case rocblas_status_size_increased:
-  case rocblas_status_success:
-  case rocblas_status_continue:
-    return HIPBLAS_STATUS_SUCCESS;
-  case rocblas_status_invalid_handle:
-    return HIPBLAS_STATUS_NOT_INITIALIZED;
-  case rocblas_status_not_implemented:
-  case rocblas_status_excluded_from_build:
-    return HIPBLAS_STATUS_NOT_SUPPORTED;
-  case rocblas_status_invalid_pointer:
-  case rocblas_status_invalid_size:
-  case rocblas_status_invalid_value:
-  case rocblas_status_size_query_mismatch:
-    return HIPBLAS_STATUS_INVALID_VALUE;
-  case rocblas_status_memory_error:
-    return HIPBLAS_STATUS_ALLOC_FAILED;
-  case rocblas_status_internal_error:
-  case rocblas_status_perf_degraded:
-  case rocblas_status_check_numerics_fail:
-    return HIPBLAS_STATUS_INTERNAL_ERROR;
-  case rocblas_status_arch_mismatch:
-    return HIPBLAS_STATUS_ARCH_MISMATCH;
-  }
-  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
 // FP64 Wrapper around cublas GEMMEx
 cublasStatus_t mlp_gemm(
     cublasHandle_t handle,
...
@@ -126,10 +78,10 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
-      (rocblas_handle)handle,
-      hipOperationToRocOperation(transa),
-      hipOperationToRocOperation(transb),
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
      m,
      n,
      k,
...
@@ -150,7 +102,7 @@ cublasStatus_t mlp_gemm(
      rocblas_datatype_f64_r,
      rocblas_gemm_algo_standard,
      0,
-      flag));
+      flag);
 #else
   return cublasGemmEx(
      handle,
...
@@ -193,10 +145,10 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
-      (rocblas_handle)handle,
-      hipOperationToRocOperation(transa),
-      hipOperationToRocOperation(transb),
+  return rocblas_gemm_ex(
+      handle,
+      transa,
+      transb,
      m,
      n,
      k,
...
@@ -217,7 +169,7 @@ cublasStatus_t mlp_gemm(
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard,
      0,
-      flag));
+      flag);
 #else
   return cublasGemmEx(
...
@@ -261,61 +213,31 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
-    half h_alpha = __float2half(*alpha);
-    half h_beta = __float2half(*beta);
-    return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(transa), hipOperationToRocOperation(transb), m, n, k,
-        /* alpha */ &h_alpha, A, rocblas_datatype_f16_r, lda, B, rocblas_datatype_f16_r, ldb,
-        /* beta */ &h_beta, C, rocblas_datatype_f16_r, ldc, C, rocblas_datatype_f16_r, ldc,
-        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard, 0, flag));
-  } else {
-    return rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle)handle,
-        hipOperationToRocOperation(transa), hipOperationToRocOperation(transb), m, n, k,
-        alpha, A, rocblas_datatype_f16_r, lda, B, rocblas_datatype_f16_r, ldb,
-        beta, C, rocblas_datatype_f16_r, ldc, C, rocblas_datatype_f16_r, ldc,
-        rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, flag));
-  }
+  return rocblas_gemm_ex(handle, transa, transb, m, n, k,
+      alpha, A, rocblas_datatype_f16_r, lda, B, rocblas_datatype_f16_r, ldb,
+      beta, C, rocblas_datatype_f16_r, ldc, C, rocblas_datatype_f16_r, ldc,
+      rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, flag);
 #else
   return cublasGemmEx(
      handle,
...
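The branch removed from the FP16 mlp_gemm keyed half-precision accumulation off an environment variable read through parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF") from "utils.h". A plausible stand-in for that toggle is sketched below; the parsing rule (unset or "0" disables, anything else enables) is an assumption, not the documented behaviour of the real helper.

// Sketch only: a std::getenv-based stand-in for the removed env-var toggle.
#include <cstdio>
#include <cstdlib>
#include <cstring>

static bool parse_env_flag(const char *name) {
  const char *v = std::getenv(name);
  return v != nullptr && std::strcmp(v, "0") != 0;  // assumed rule, see note above
}

int main() {
  // The removed code picked half-precision GEMM accumulation only when the flag
  // was set, and fell back to fp32 accumulation otherwise.
  const bool allow_half = parse_env_flag("APEX_ROCBLAS_GEMM_ALLOW_HALF");
  std::printf("compute type: %s\n", allow_half ? "f16" : "f32");
  return 0;
}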