Commit b5d7745d authored by flyingdown's avatar flyingdown
Browse files

Merge mirror master branch

parents 03204b84 3ba7192d
......@@ -94,155 +94,79 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
h_beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
......@@ -276,104 +200,53 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_q_results,
......@@ -465,32 +338,57 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
#endif
if (use_fp16) {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
......@@ -680,308 +578,155 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
if (use_fp16) {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
......
......@@ -119,158 +119,80 @@ std::vector<torch::Tensor> fwd_cuda(
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
if (use_fp16) {
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
k_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
h_beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
......@@ -303,108 +225,55 @@ std::vector<torch::Tensor> fwd_cuda(
(1.0f - dropout_prob));
}
if (use_fp16) {
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// End-of-block Dropout-Add
if (is_training) {
......@@ -533,32 +402,57 @@ std::vector<torch::Tensor> bwd_cuda(
total_tokens_q,
(1.0 / (1.0 - dropout_prob)));
if (use_fp16) {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
......@@ -749,310 +643,156 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
if (use_fp16) {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&h_beta),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>(
......@@ -1080,4 +820,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
} // end namespace multihead_attn
\ No newline at end of file
......@@ -90,104 +90,53 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
h_beta_zero,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
......@@ -213,108 +162,55 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, bmm1_results, dropout_results,
......@@ -392,442 +288,222 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half* const>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half* const>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half* const>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
} // end namespace rocblas_gemmex
......
......@@ -88,104 +88,53 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
h_beta_zero,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
......@@ -219,108 +168,55 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
}
// Matmul2
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
h_beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
......@@ -398,432 +294,218 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
} // end namespace multihead_attn
\ No newline at end of file
......@@ -85,9 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
......@@ -108,7 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -188,9 +188,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
......@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
......@@ -289,202 +289,102 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
}
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
......@@ -504,202 +404,102 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
......
......@@ -106,106 +106,54 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
h_beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Padded Softmax
bool softmax_success = false;
......@@ -239,106 +187,54 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// End-of-block Dropout-Add
......@@ -451,202 +347,102 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
}
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
......@@ -666,206 +462,104 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&h_beta),
//static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half, float>(
......@@ -889,4 +583,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
} // end namespace multihead_attn
\ No newline at end of file
......@@ -7,8 +7,6 @@
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <rocblas/rocblas.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
......@@ -47,52 +45,6 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
}
}
// needed to work around calling rocblas API instead of hipblas API
// Convert a hipBLAS transpose flag into the matching rocBLAS flag.
// Required because this translation unit invokes the rocBLAS C API
// directly while the surrounding framework traffics in hipBLAS enums.
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
    if (op == HIPBLAS_OP_N) {
        return rocblas_operation_none;
    }
    if (op == HIPBLAS_OP_T) {
        return rocblas_operation_transpose;
    }
    if (op == HIPBLAS_OP_C) {
        return rocblas_operation_conjugate_transpose;
    }
    // Unknown enumerator: fail loudly rather than guess a transpose mode.
    AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// Map a rocBLAS status code onto the closest hipBLAS status so that
// TORCH_CUDABLAS_CHECK (which expects hipBLAS statuses on ROCm builds)
// can wrap direct rocblas_* API calls.
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
// All "success-like" rocBLAS outcomes collapse to plain success.
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
// A bad handle indicates the library was never (or incorrectly) initialized.
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
// Features absent from this build or not yet implemented.
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
// Caller-supplied argument problems all become INVALID_VALUE.
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
// Internal library failures, including numerics-check failures.
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
// Reached only for an enumerator not handled above; abort with an explicit error.
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
......@@ -105,13 +57,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float fAlpha = alpha;
float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)));
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
}
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
......
......@@ -10,22 +10,10 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
#endif
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_16F HIPBLAS_R_16F
// FP64 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
cublasHandle_t handle,
......@@ -42,6 +30,33 @@ cublasStatus_t gemm_bias(
const float* beta,
double* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f64_r,
lda,
B,
rocblas_datatype_f64_r,
ldb,
beta,
C,
rocblas_datatype_f64_r,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
......@@ -62,6 +77,7 @@ cublasStatus_t gemm_bias(
ldc,
CUDA_R_64F,
CUBLAS_GEMM_DEFAULT);
#endif
}
// FP32 Wrapper around cublas GEMMEx
......@@ -80,6 +96,34 @@ cublasStatus_t gemm_bias(
const float* beta,
float* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
lda,
B,
rocblas_datatype_f32_r,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
......@@ -100,6 +144,7 @@ cublasStatus_t gemm_bias(
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT);
#endif
}
// FP16 Tensor core wrapper around cublas GEMMEx
......@@ -118,6 +163,7 @@ cublasStatus_t gemm_bias(
const float* beta,
at::Half* C,
int ldc) {
<<<<<<< HEAD
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta);
......@@ -163,6 +209,56 @@ cublasStatus_t gemm_bias(
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
=======
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
>>>>>>> mirror/master
}
......
......@@ -13,8 +13,6 @@
#include <cuda_runtime.h>
#include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
......@@ -62,52 +60,6 @@ __device__ __inline__ float sigmoid(float a) {
return (retf);
}
// needed to work around calling rocblas API instead of hipblas API
// Convert hipblasOperation_t -> rocblas_operation so rocblas_* entry points
// can be called with the hipBLAS transpose flags used elsewhere in this file.
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
// Straight one-to-one enum translation: hipBLAS and rocBLAS use distinct
// enum types for the same three transpose modes.
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
// No default case above: any other value is a programming error.
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// Translate a rocblas_status into the hipblasStatus_t expected by the
// TORCH_CUDABLAS_CHECK error-checking macro. Several rocBLAS codes map
// onto a single hipBLAS code, so related enumerators share a branch.
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
    switch (error)
    {
        // Every non-error outcome is reported as plain success.
        case rocblas_status_success:
        case rocblas_status_continue:
        case rocblas_status_size_unchanged:
        case rocblas_status_size_increased:
            return HIPBLAS_STATUS_SUCCESS;
        // A bad handle means the library was never (or incorrectly) set up.
        case rocblas_status_invalid_handle:
            return HIPBLAS_STATUS_NOT_INITIALIZED;
        // Functionality missing from this build or not yet implemented.
        case rocblas_status_not_implemented:
        case rocblas_status_excluded_from_build:
            return HIPBLAS_STATUS_NOT_SUPPORTED;
        // Caller-supplied argument problems.
        case rocblas_status_invalid_pointer:
        case rocblas_status_invalid_size:
        case rocblas_status_invalid_value:
        case rocblas_status_size_query_mismatch:
            return HIPBLAS_STATUS_INVALID_VALUE;
        case rocblas_status_memory_error:
            return HIPBLAS_STATUS_ALLOC_FAILED;
        // Internal library failures, including numerics checks.
        case rocblas_status_internal_error:
        case rocblas_status_perf_degraded:
        case rocblas_status_check_numerics_fail:
            return HIPBLAS_STATUS_INTERNAL_ERROR;
        case rocblas_status_arch_mismatch:
            return HIPBLAS_STATUS_ARCH_MISMATCH;
    }
    // Unrecognized enumerator — abort with an explicit error.
    AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
cublasHandle_t handle,
......@@ -126,10 +78,10 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
......@@ -150,7 +102,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
flag));
flag);
#else
return cublasGemmEx(
handle,
......@@ -193,10 +145,10 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
......@@ -217,7 +169,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag));
flag);
#else
return cublasGemmEx(
......@@ -261,61 +213,31 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta);
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
m,
n,
k,
/* alpha */ &h_alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
/* beta */ &h_beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard,
0,
flag);
} else {
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag);
}
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag);
#else
return cublasGemmEx(
handle,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment