Commit b5d7745d authored by flyingdown

merge mirror master

parents 03204b84 3ba7192d
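Across every hunk below, the merge collapses the old use_fp16 / else branches into a single code path: the wrapped form rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, hipOperationToRocOperation(...), ..., &h_alpha, ..., /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, ...)) is replaced by a direct rocblas_gemm_ex call that keeps fp16 (rocblas_datatype_f16_r) storage for A, B, C and D but always passes fp32 alpha/beta and rocblas_datatype_f32_r as the compute type. The stand-alone sketch below is an editorial illustration of that call pattern only, not code from this commit; the sizes, buffer names and zero-filled inputs are invented, the header path may be <rocblas.h> on older ROCm releases, and where the diff passes CUBLAS_OP_* constants through PyTorch's hipified headers this sketch uses the native rocblas_operation_* enums.

// Editorial sketch: one fp16-storage / fp32-accumulation GEMM, C = alpha * A^T * B + beta * C.
#include <hip/hip_runtime.h>
#include <rocblas/rocblas.h>
#include <cstdio>

int main() {
  const rocblas_int m = 64, n = 64, k = 64;      // invented sizes
  const float alpha = 1.0f, beta = 0.0f;         // fp32 scalars, matching the fp32 compute type

  rocblas_handle handle;
  rocblas_create_handle(&handle);

  // fp16 device buffers: A is k x m (consumed transposed), B is k x n, C/D is m x n (in place).
  rocblas_half *d_a, *d_b, *d_c;
  hipMalloc(&d_a, sizeof(rocblas_half) * k * m);
  hipMalloc(&d_b, sizeof(rocblas_half) * k * n);
  hipMalloc(&d_c, sizeof(rocblas_half) * m * n);
  hipMemset(d_a, 0, sizeof(rocblas_half) * k * m);
  hipMemset(d_b, 0, sizeof(rocblas_half) * k * n);
  hipMemset(d_c, 0, sizeof(rocblas_half) * m * n);

  // Same argument layout as the new calls in this diff; C doubles as the output matrix D.
  rocblas_status status = rocblas_gemm_ex(handle,
                                          rocblas_operation_transpose,
                                          rocblas_operation_none,
                                          m, n, k,
                                          &alpha,
                                          d_a, rocblas_datatype_f16_r, k,
                                          d_b, rocblas_datatype_f16_r, k,
                                          &beta,
                                          d_c, rocblas_datatype_f16_r, m,
                                          d_c, rocblas_datatype_f16_r, m,
                                          rocblas_datatype_f32_r,          // accumulate in fp32
                                          rocblas_gemm_algo_standard,
                                          0 /*solution_index*/,
                                          0 /*flags*/);
  std::printf("rocblas_gemm_ex returned %d\n", static_cast<int>(status));

  hipFree(d_a); hipFree(d_b); hipFree(d_c);
  rocblas_destroy_handle(handle);
  return 0;
}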
@@ -94,155 +94,79 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd
if (use_fp16) { TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), output_lin_q_dim,
output_lin_q_dim, batches_q,
batches_q, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights_q.data_ptr()),
static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(inputs_q.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_q_dim,
output_lin_q_dim, q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_q_dim,
output_lin_q_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear KV Fwd
// Input Linear KV Fwd TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), output_lin_kv_dim,
output_lin_kv_dim, batches_kv,
batches_kv, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights_kv.data_ptr()),
static_cast<const void*>(input_weights_kv.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(inputs_kv.data_ptr()),
static_cast<const void*>(inputs_kv.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), k_lin_results_ptr,
k_lin_results_ptr, rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_kv_dim,
output_lin_kv_dim, k_lin_results_ptr,
k_lin_results_ptr, rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_kv_dim,
output_lin_kv_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) gemm_switch_fp32accum( a_layout_t,
gemm_switch_fp32accum( a_layout_t, b_layout_n,
b_layout_n, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, head_dim,
head_dim, scale,
h_scale, static_cast<const half*>(k_lin_results_ptr),
static_cast<const half*>(k_lin_results_ptr), lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, static_cast<const half*>(q_lin_results_ptr),
static_cast<const half*>(q_lin_results_ptr), lead_dim_q,
lead_dim_q, batch_stride_q,
batch_stride_q, beta,
h_beta, static_cast<half*>(softmax_results_ptr),
static_cast<half*>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, static_cast<half*>(softmax_results_ptr),
static_cast<half*>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, attn_batches,
attn_batches, flags);
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
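// Editorial note, not code from this repository: the "MatMul1" steps above and below go through
// the project helper gemm_switch_fp32accum(a_layout_t, b_layout_n, ...). Assuming that helper
// dispatches to a strided-batched GEMM with fp16 operands and fp32 accumulation, an equivalent
// direct rocBLAS call for the attention-score matmul (scores = scale * K^T * Q per batch) could
// look like the sketch below. The function name and wrapper signature are invented, and the exact
// rocblas_gemm_strided_batched_ex parameter order should be checked against the rocBLAS headers.
#include <rocblas/rocblas.h>   // <rocblas.h> on older ROCm

// scores (k_seq_len x q_seq_len per batch) = scale * K^T * Q, fp16 in/out, fp32 accumulation.
rocblas_status attn_scores_strided_batched(rocblas_handle handle,
                                           const rocblas_half* k, rocblas_int lead_dim_kv, rocblas_stride batch_stride_kv,
                                           const rocblas_half* q, rocblas_int lead_dim_q, rocblas_stride batch_stride_q,
                                           rocblas_half* scores,
                                           rocblas_int k_seq_len, rocblas_int q_seq_len, rocblas_int head_dim,
                                           rocblas_int attn_batches, float scale) {
  const float beta = 0.0f;
  return rocblas_gemm_strided_batched_ex(handle,
                                         rocblas_operation_transpose,   // K consumed transposed (a_layout_t)
                                         rocblas_operation_none,        // Q as-is (b_layout_n)
                                         k_seq_len, q_seq_len, head_dim,
                                         &scale,                        // folds in the 1/sqrt(head_dim) scaling
                                         k, rocblas_datatype_f16_r, lead_dim_kv, batch_stride_kv,
                                         q, rocblas_datatype_f16_r, lead_dim_q, batch_stride_q,
                                         &beta,
                                         scores, rocblas_datatype_f16_r, k_seq_len, (rocblas_stride)k_seq_len * q_seq_len,
                                         scores, rocblas_datatype_f16_r, k_seq_len, (rocblas_stride)k_seq_len * q_seq_len,
                                         attn_batches,
                                         rocblas_datatype_f32_r,        // fp32 accumulation
                                         rocblas_gemm_algo_standard,
                                         0 /*solution_index*/, 0 /*flags*/);
}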
// Padded Softmax
bool softmax_success = false;
@@ -276,104 +200,53 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
if (use_fp16) { gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, alpha,
h_alpha, static_cast<const half*>(v_lin_results_ptr),
static_cast<const half*>(v_lin_results_ptr), lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, attn_batches,
attn_batches, flags);
flags);
// Output Linear
// Output Linear TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches_q,
batches_q, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()),
static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(matmul2_results.data_ptr()),
static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(outputs.data_ptr()),
static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
embed_dim, embed_dim,
static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f32_r,
rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
embed_dim, 0 /*solution_index*/,
rocblas_datatype_f32_r, flags));
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_q_results,
@@ -465,32 +338,57 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
#endif
-  if (use_fp16) {
-    // Output Linear Dgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-                  hipOperationToRocOperation(CUBLAS_OP_N),
-                  hipOperationToRocOperation(CUBLAS_OP_N),
-                  embed_dim,
-                  batches_q,
-                  embed_dim,
-                  static_cast<const void*>(&h_alpha),
-                  static_cast<const void*>(output_weights.data_ptr()),
-                  rocblas_datatype_f16_r,
-                  embed_dim,
-                  static_cast<const void*>(output_grads.data_ptr()),
-                  rocblas_datatype_f16_r,
-                  embed_dim,
-                  static_cast<const void*>(&h_beta),
-                  static_cast<void*>(output_lin_grads.data_ptr()),
-                  rocblas_datatype_f16_r,
-                  embed_dim,
-                  static_cast<void*>(output_lin_grads.data_ptr()),
-                  rocblas_datatype_f16_r,
-                  embed_dim,
-                  /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
-                  rocblas_gemm_algo_standard /*algo*/,
-                  0 /*solution_index*/,
-                  flags)));
+  // Output Linear Dgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+                  CUBLAS_OP_N,
+                  CUBLAS_OP_N,
+                  embed_dim,
+                  batches_q,
+                  embed_dim,
+                  static_cast<const void*>(&alpha),
+                  static_cast<const void*>(output_weights.data_ptr()),
+                  rocblas_datatype_f16_r,
+                  embed_dim,
+                  static_cast<const void*>(output_grads.data_ptr()),
+                  rocblas_datatype_f16_r,
+                  embed_dim,
+                  static_cast<const void*>(&beta),
+                  static_cast<void*>(output_lin_grads.data_ptr()),
+                  rocblas_datatype_f16_r,
+                  embed_dim,
+                  static_cast<void*>(output_lin_grads.data_ptr()),
+                  rocblas_datatype_f16_r,
+                  embed_dim,
+                  rocblas_datatype_f32_r,
+                  rocblas_gemm_algo_standard /*algo*/,
+                  0 /*solution_index*/,
+                  flags));
+  // Output Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+                  CUBLAS_OP_N,
+                  CUBLAS_OP_T,
+                  embed_dim,
+                  embed_dim,
+                  batches_q,
+                  static_cast<const void*>(&alpha),
+                  static_cast<const void*>(matmul2_results.data_ptr()),
+                  rocblas_datatype_f16_r,
+                  embed_dim,
+                  static_cast<const void*>(output_grads.data_ptr()),
+                  rocblas_datatype_f16_r,
+                  embed_dim,
+                  static_cast<const void*>(&beta),
+                  static_cast<void*>(output_weight_grads.data_ptr()),
+                  rocblas_datatype_f16_r,
+                  embed_dim,
+                  static_cast<void*>(output_weight_grads.data_ptr()),
+                  rocblas_datatype_f16_r,
+                  embed_dim,
+                  rocblas_datatype_f32_r,
+                  rocblas_gemm_algo_standard /*algo*/,
+                  0 /*solution_index*/,
+                  flags));
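// Editorial note (not part of the commit): the Wgrad call above uses (CUBLAS_OP_N, CUBLAS_OP_T)
// with m = n = embed_dim and k = batches_q, i.e. in column-major layout it computes
// output_weight_grads = matmul2_results * output_grads^T with fp32 accumulation.
// The tiny host-side reference below, with made-up sizes and values, spells out that computation.
#include <vector>
#include <cstdio>

// Column-major, like rocBLAS: A is m x k (lda = m), B is n x k (ldb = n), C = alpha*A*B^T + beta*C is m x n (ldc = m).
void gemm_nt_ref(int m, int n, int k, float alpha,
                 const std::vector<float>& A, const std::vector<float>& B,
                 float beta, std::vector<float>& C) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;                        // fp32 accumulation, as compute_type f32 requests
      for (int p = 0; p < k; ++p)
        acc += A[i + p * m] * B[j + p * n];   // B^T(p, j) = B(j, p)
      C[i + j * m] = alpha * acc + beta * C[i + j * m];
    }
}

int main() {
  // For the Wgrad above: m = n = embed_dim, k = batches_q,
  // A = activations, B = output gradients, C = weight gradients (updated in place).
  const int m = 2, n = 2, k = 3;
  std::vector<float> A = {1, 2, 3, 4, 5, 6};   // 2 x 3, column-major
  std::vector<float> B = {1, 0, 0, 1, 1, 1};   // 2 x 3, column-major
  std::vector<float> C(m * n, 0.f);
  gemm_nt_ref(m, n, k, 1.f, A, B, 0.f, C);
  std::printf("C = [%g %g; %g %g]\n", C[0], C[2], C[1], C[3]);
  return 0;
}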
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
@@ -680,308 +578,155 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
if (use_fp16) { // Matmul1 Dgrad1
// Matmul1 Dgrad1 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, scale,
h_scale, k_lin_results_ptr,
k_lin_results_ptr, lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim_q,
lead_dim_q, batch_stride_q,
batch_stride_q, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim_q,
lead_dim_q, batch_stride_q,
batch_stride_q, attn_batches,
attn_batches, flags);
flags);
// Matmul1 Dgrad2
// Matmul1 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, scale,
h_scale, q_lin_results_ptr,
q_lin_results_ptr, lead_dim_q,
lead_dim_q, batch_stride_q,
batch_stride_q, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, attn_batches,
attn_batches, flags);
flags);
// Input Linear Q Dgrad
// Input Linear Q Dgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches_q,
batches_q, output_lin_q_dim,
output_lin_q_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights_q.data_ptr()),
static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_q_dim,
output_lin_q_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_q_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_q_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear Q Wgrad
// Input Linear Q Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, output_lin_q_dim,
output_lin_q_dim, batches_q,
batches_q, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(inputs_q.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_q_dim,
output_lin_q_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_q_grads.data_ptr()),
static_cast<void*>(input_weight_q_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_weight_q_grads.data_ptr()),
static_cast<void*>(input_weight_q_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear KV Dgrad
// Input Linear KV Dgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches_kv,
batches_kv, output_lin_kv_dim,
output_lin_kv_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights_kv.data_ptr()),
static_cast<const void*>(input_weights_kv.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(k_lin_grads_ptr),
static_cast<const void*>(k_lin_grads_ptr), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_kv_dim,
output_lin_kv_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_kv_grads.data_ptr()),
static_cast<void*>(input_kv_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_kv_grads.data_ptr()),
static_cast<void*>(input_kv_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear KV Wgrad
// Input Linear KV Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, output_lin_kv_dim,
output_lin_kv_dim, batches_kv,
batches_kv, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(inputs_kv.data_ptr()),
static_cast<const void*>(inputs_kv.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(k_lin_grads_ptr),
static_cast<const void*>(k_lin_grads_ptr), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_kv_dim,
output_lin_kv_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_kv_grads.data_ptr()),
static_cast<void*>(input_weight_kv_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_weight_kv_grads.data_ptr()),
static_cast<void*>(input_weight_kv_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
} else {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
...
@@ -119,158 +119,80 @@ std::vector<torch::Tensor> fwd_cuda(
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
if (use_fp16) { // Input Linear Q Fwd
// Input Linear Q Fwd TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), output_lin_q_dim,
output_lin_q_dim, batches_q,
batches_q, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights_q.data_ptr()),
static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, //static_cast<const void*>(inputs_q.data_ptr()),
//static_cast<const void*>(inputs_q.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, output_lin_q_dim,
output_lin_q_dim, q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, output_lin_q_dim,
output_lin_q_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear KV Fwd
// Input Linear KV Fwd TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), output_lin_kv_dim,
output_lin_kv_dim, batches_kv,
batches_kv, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights_kv.data_ptr()),
static_cast<const void*>(input_weights_kv.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(inputs_kv.data_ptr()),
static_cast<const void*>(inputs_kv.data_ptr()), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), k_lin_results_ptr,
k_lin_results_ptr, rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, output_lin_kv_dim,
output_lin_kv_dim, k_lin_results_ptr,
k_lin_results_ptr, rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, output_lin_kv_dim,
output_lin_kv_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags))); // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) gemm_switch_fp32accum( a_layout_t,
gemm_switch_fp32accum( a_layout_t, b_layout_n,
b_layout_n, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, head_dim,
head_dim, scale,
h_scale, static_cast<const half*>(k_lin_results_ptr),
static_cast<const half*>(k_lin_results_ptr), lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, static_cast<const half*>(q_lin_results_ptr),
static_cast<const half*>(q_lin_results_ptr), lead_dim_q,
lead_dim_q, batch_stride_q,
batch_stride_q, beta,
h_beta, static_cast<half*>(softmax_results_ptr),
static_cast<half*>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, static_cast<half*>(softmax_results_ptr),
static_cast<half*>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, attn_batches,
attn_batches, flags);
flags);
} else {
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
// Padded Softmax
bool softmax_success = false;
@@ -303,108 +225,55 @@ std::vector<torch::Tensor> fwd_cuda(
(1.0f - dropout_prob));
}
if (use_fp16) { // Matmul2
// Matmul2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, alpha,
h_alpha, static_cast<const half*>(v_lin_results_ptr),
static_cast<const half*>(v_lin_results_ptr), lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()), //static_cast<const half*>(dropout_results.data_ptr()),
//static_cast<const half*>(dropout_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, attn_batches,
attn_batches, flags);
flags);
// Output Linear
// Output Linear TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches_q,
batches_q, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()),
static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(matmul2_results.data_ptr()),
static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_lin_results.data_ptr()),
static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(output_lin_results.data_ptr()),
static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
} else {
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
// End-of-block Dropout-Add
if (is_training) {
@@ -533,32 +402,57 @@ std::vector<torch::Tensor> bwd_cuda(
total_tokens_q,
(1.0 / (1.0 - dropout_prob)));
-  if (use_fp16) {
-    // Output Linear Dgrad
-    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-                  hipOperationToRocOperation(CUBLAS_OP_N),
-                  hipOperationToRocOperation(CUBLAS_OP_N),
-                  embed_dim,
-                  batches_q,
-                  embed_dim,
-                  static_cast<const void*>(&h_alpha),
-                  static_cast<const void*>(output_weights.data_ptr()),
-                  rocblas_datatype_f16_r /*a_type*/,
-                  embed_dim,
-                  static_cast<const void*>(dropout_add_grads.data_ptr()),
-                  rocblas_datatype_f16_r /*b_type*/,
-                  embed_dim,
-                  static_cast<const void*>(&h_beta),
-                  static_cast<void*>(output_lin_grads.data_ptr()),
-                  rocblas_datatype_f16_r /*c_type*/,
-                  embed_dim,
-                  static_cast<void*>(output_lin_grads.data_ptr()),
-                  rocblas_datatype_f16_r /*d_type*/,
-                  embed_dim,
-                  /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
-                  rocblas_gemm_algo_standard /*algo*/,
-                  0 /*solution_index*/,
-                  flags)));
+  // Output Linear Dgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+                  CUBLAS_OP_N,
+                  CUBLAS_OP_N,
+                  embed_dim,
+                  batches_q,
+                  embed_dim,
+                  static_cast<const void*>(&alpha),
+                  static_cast<const void*>(output_weights.data_ptr()),
+                  rocblas_datatype_f16_r /*a_type*/,
+                  embed_dim,
+                  static_cast<const void*>(dropout_add_grads.data_ptr()),
+                  rocblas_datatype_f16_r /*b_type*/,
+                  embed_dim,
+                  static_cast<const void*>(&beta),
+                  static_cast<void*>(output_lin_grads.data_ptr()),
+                  rocblas_datatype_f16_r /*c_type*/,
+                  embed_dim,
+                  static_cast<void*>(output_lin_grads.data_ptr()),
+                  rocblas_datatype_f16_r /*d_type*/,
+                  embed_dim,
+                  rocblas_datatype_f32_r /*compute_type*/,
+                  rocblas_gemm_algo_standard /*algo*/,
+                  0 /*solution_index*/,
+                  flags));
+  // Output Linear Wgrad
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+                  CUBLAS_OP_N,
+                  CUBLAS_OP_T,
+                  embed_dim,
+                  embed_dim,
+                  batches_q,
+                  static_cast<const void*>(&alpha),
+                  static_cast<const void*>(matmul2_results.data_ptr()),
+                  rocblas_datatype_f16_r /*a_type*/,
+                  embed_dim,
+                  static_cast<const void*>(dropout_add_grads.data_ptr()),
+                  rocblas_datatype_f16_r /*b_type*/,
+                  embed_dim,
+                  static_cast<const void*>(&beta),
+                  static_cast<void*>(output_weight_grads.data_ptr()),
+                  rocblas_datatype_f16_r /*c_type*/,
+                  embed_dim,
+                  static_cast<void*>(output_weight_grads.data_ptr()),
+                  rocblas_datatype_f16_r /*d_type*/,
+                  embed_dim,
+                  rocblas_datatype_f32_r /*compute_type*/,
+                  rocblas_gemm_algo_standard /*algo*/,
+                  0 /*solution_index*/,
+                  flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
@@ -749,310 +643,156 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
if (use_fp16) { // Matmul1 Dgrad1
// Matmul1 Dgrad1 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, scale,
h_scale, k_lin_results_ptr,
k_lin_results_ptr, lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim_q,
lead_dim_q, batch_stride_q,
batch_stride_q, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim_q,
lead_dim_q, batch_stride_q,
batch_stride_q, attn_batches,
attn_batches, flags);
flags);
// Matmul1 Dgrad2
// Matmul1 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, scale,
h_scale, q_lin_results_ptr,
q_lin_results_ptr, lead_dim_q,
lead_dim_q, batch_stride_q,
batch_stride_q, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim_kv,
lead_dim_kv, batch_stride_kv,
batch_stride_kv, attn_batches,
attn_batches, flags);
flags);
// Input Linear Q Dgrad
// Input Linear Q Dgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches_q,
batches_q, output_lin_q_dim,
output_lin_q_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights_q.data_ptr()),
static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, output_lin_q_dim,
output_lin_q_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), //static_cast<void*>(input_q_grads.data_ptr()),
//static_cast<void*>(input_q_grads.data_ptr()), static_cast<void*>(input_lin_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(input_lin_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear Q Wgrad
// Input Linear Q Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, output_lin_q_dim,
output_lin_q_dim, batches_q,
batches_q, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(inputs_q.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, output_lin_q_dim,
output_lin_q_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_q_grads.data_ptr()),
static_cast<void*>(input_weight_q_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(input_weight_q_grads.data_ptr()),
static_cast<void*>(input_weight_q_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear KV Dgrad
// Input Linear KV Dgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches_kv,
batches_kv, output_lin_kv_dim,
output_lin_kv_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights_kv.data_ptr()),
static_cast<const void*>(input_weights_kv.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(k_lin_grads_ptr),
static_cast<const void*>(k_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, output_lin_kv_dim,
output_lin_kv_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_kv_grads.data_ptr()),
static_cast<void*>(input_kv_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(input_kv_grads.data_ptr()),
static_cast<void*>(input_kv_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear KV Wgrad
// Input Linear KV Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, output_lin_kv_dim,
output_lin_kv_dim, batches_kv,
batches_kv, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(inputs_kv.data_ptr()),
static_cast<const void*>(inputs_kv.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(k_lin_grads_ptr),
static_cast<const void*>(k_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, output_lin_kv_dim,
output_lin_kv_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_kv_grads.data_ptr()),
static_cast<void*>(input_weight_kv_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(input_weight_kv_grads.data_ptr()),
static_cast<void*>(input_weight_kv_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
} else {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>(
@@ -1080,4 +820,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
\ No newline at end of file
@@ -90,104 +90,53 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
if (use_fp16) { TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), output_lin_dim,
output_lin_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()),
static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta_one),
static_cast<const void*>(&h_beta_one), q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) gemm_switch_fp32accum( a_layout_t,
gemm_switch_fp32accum( a_layout_t, b_layout_n,
b_layout_n, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, head_dim,
head_dim, scale,
h_scale, static_cast<const half*>(k_lin_results_ptr),
static_cast<const half*>(k_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<const half*>(q_lin_results_ptr),
static_cast<const half*>(q_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, beta_zero,
h_beta_zero, static_cast<half*>(bmm1_results_ptr),
static_cast<half*>(bmm1_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, static_cast<half*>(bmm1_results_ptr),
static_cast<half*>(bmm1_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, attn_batches,
attn_batches, flags);
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
// Padded Softmax
bool softmax_success = false;
@@ -213,108 +162,55 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
if (use_fp16) { gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, alpha,
h_alpha, static_cast<const half*>(v_lin_results_ptr),
static_cast<const half*>(v_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<const half*>(dropout_results.data_ptr()),
static_cast<const half*>(dropout_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta_zero,
h_beta_zero, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, attn_batches,
attn_batches, flags);
flags);
outputs.copy_(output_biases);
outputs.copy_(output_biases);
// Output Linear
// Output Linear TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()),
static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(matmul2_results.data_ptr()),
static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta_one),
static_cast<const void*>(&h_beta_one), static_cast<void*>(outputs.data_ptr()),
static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(outputs.data_ptr()),
static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, bmm1_results, dropout_results,
...@@ -392,442 +288,222 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) { TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()),
static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(output_grads.data_ptr()),
static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_lin_grads.data_ptr()),
static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(output_lin_grads.data_ptr()),
static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Output Linear Wgrad
// Output Linear Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, embed_dim,
embed_dim, batches,
batches, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(matmul2_results.data_ptr()),
static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(output_grads.data_ptr()),
static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_weight_grads.data_ptr()),
static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(output_weight_grads.data_ptr()),
static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); // MatMul2 Dgrad1
// MatMul2 Dgrad1 gemm_switch_fp32accum( a_layout_t,
gemm_switch_fp32accum( a_layout_t, b_layout_n,
b_layout_n, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, head_dim,
head_dim, alpha,
h_alpha, static_cast<const half*>(v_lin_results_ptr),
static_cast<const half*>(v_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<const half*>(output_lin_grads.data_ptr()),
static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, beta,
h_beta, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, attn_batches,
attn_batches, flags);
flags);
// Matmul2 Dgrad2
// Matmul2 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, alpha,
h_alpha, static_cast<const half*>(output_lin_grads.data_ptr()),
static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, static_cast<const half*>(dropout_results.data_ptr()),
static_cast<const half*>(dropout_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, v_lin_grads_ptr,
v_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, v_lin_grads_ptr,
v_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad
// Softmax Grad dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>( static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<half* const>(matmul2_grads.data_ptr()),
static_cast<half* const>(matmul2_grads.data_ptr()), reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()), reinterpret_cast<half const*>(pad_mask.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()), static_cast<uint8_t const*>(dropout_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()), 1.0/(1.0-dropout_prob),
1.0/(1.0-dropout_prob), k_seq_len,
k_seq_len, k_seq_len,
k_seq_len, attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len/sequences, attn_batches*q_seq_len,
attn_batches*q_seq_len, stream);
stream);
// Matmul1 Dgrad1
// Matmul1 Dgrad1 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, scale,
h_scale, k_lin_results_ptr,
k_lin_results_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
// Matmul1 Dgrad2
// Matmul1 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, scale,
h_scale, q_lin_results_ptr,
q_lin_results_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
// Input Linear Dgrad
// Input Linear Dgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, output_lin_dim,
output_lin_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()),
static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(input_lin_output_grads.data_ptr()),
static_cast<const void*>(input_lin_output_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear Wgrad
// Input Linear Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, output_lin_dim,
output_lin_dim, batches,
batches, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_grads.data_ptr()),
static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_weight_grads.data_ptr()),
static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads};
input_bias_grads, output_bias_grads};
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half* const>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
}
} // end namespace rocblas_gemmex
......
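Note on the math behind these hunks: the fwd_cuda / bwd_cuda pair above (and the variants in the files that follow) implements standard scaled dot-product attention as a sequence of batched GEMMs around fused masked-softmax/dropout kernels. A compact summary, with notation of ours rather than the repository's (Q, K, V are the per-head projections, P the post-softmax, post-dropout probabilities, d the head size; the linear layers follow the column-major Y = WᵀX convention implied by the OP_T/OP_N arguments):

    S  = \tfrac{1}{\sqrt{d}}\, Q K^{\top}                          % "MatMul1 ... scaling by 1/Sqrt(head size)"
    P  = \mathrm{dropout}(\mathrm{softmax}(S))                     % fused softmax / dropout kernels
    O  = P V                                                       % "Matmul2"

    dP = dO\, V^{\top} \qquad dV = P^{\top} dO                     % "MatMul2 Dgrad1" / "Matmul2 Dgrad2"
    dS_{ij} = P_{ij}\bigl(dP_{ij} - \textstyle\sum_k dP_{ik} P_{ik}\bigr)   % "Softmax Grad" (dropout mask/scale folded in)
    dQ = \tfrac{1}{\sqrt{d}}\, dS\, K \qquad dK = \tfrac{1}{\sqrt{d}}\, dS^{\top} Q   % "Matmul1 Dgrad1" / "Matmul1 Dgrad2"

    dX = W\, dY \qquad dW = X\, dY^{\top}                          % linear Dgrad / Wgrad GEMMs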
...@@ -88,104 +88,53 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
if (use_fp16) { TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), output_lin_dim,
output_lin_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()),
static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta_one),
static_cast<const void*>(&h_beta_one), q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) gemm_switch_fp32accum( a_layout_t,
gemm_switch_fp32accum( a_layout_t, b_layout_n,
b_layout_n, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, head_dim,
head_dim, scale,
h_scale, static_cast<const half*>(k_lin_results_ptr),
static_cast<const half*>(k_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<const half*>(q_lin_results_ptr),
static_cast<const half*>(q_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, beta_zero,
h_beta_zero, static_cast<half*>(softmax_results_ptr),
static_cast<half*>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, static_cast<half*>(softmax_results_ptr),
static_cast<half*>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, attn_batches,
attn_batches, flags);
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta_zero,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
// Padded Softmax
bool softmax_success = false;
...@@ -219,108 +168,55 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
}
// Matmul2
if (use_fp16) { gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, alpha,
h_alpha, static_cast<const half*>(v_lin_results_ptr),
static_cast<const half*>(v_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta_zero,
h_beta_zero, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, attn_batches,
attn_batches, flags);
flags);
outputs.copy_(output_biases);
outputs.copy_(output_biases);
// Output Linear
// Output Linear TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()),
static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(matmul2_results.data_ptr()),
static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta_one),
static_cast<const void*>(&h_beta_one), static_cast<void*>(outputs.data_ptr()),
static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(outputs.data_ptr()),
static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
...@@ -398,432 +294,218 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) { TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()),
static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(output_grads.data_ptr()),
static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_lin_grads.data_ptr()),
static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(output_lin_grads.data_ptr()),
static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Output Linear Wgrad
// Output Linear Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, embed_dim,
embed_dim, batches,
batches, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(matmul2_results.data_ptr()),
static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(output_grads.data_ptr()),
static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_weight_grads.data_ptr()),
static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(output_weight_grads.data_ptr()),
static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); // MatMul2 Dgrad1
// MatMul2 Dgrad1 gemm_switch_fp32accum( a_layout_t,
gemm_switch_fp32accum( a_layout_t, b_layout_n,
b_layout_n, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, head_dim,
head_dim, alpha,
h_alpha, static_cast<const half*>(v_lin_results_ptr),
static_cast<const half*>(v_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<const half*>(output_lin_grads.data_ptr()),
static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, beta,
h_beta, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, attn_batches,
attn_batches, flags);
flags);
// Matmul2 Dgrad2
// Matmul2 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, alpha,
h_alpha, static_cast<const half*>(output_lin_grads.data_ptr()),
static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, static_cast<const half*>(dropout_results.data_ptr()),
static_cast<const half*>(dropout_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, v_lin_grads_ptr,
v_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, v_lin_grads_ptr,
v_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad
// Softmax Grad dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>( static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()), static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()), reinterpret_cast<half const *>(softmax_results.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()), static_cast<uint8_t const *>(dropout_mask.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()), 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, attn_batches * q_seq_len, stream);
attn_batches * q_seq_len, stream);
// Matmul1 Dgrad1
// Matmul1 Dgrad1 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, scale,
h_scale, k_lin_results_ptr,
k_lin_results_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
// Matmul1 Dgrad2
// Matmul1 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, scale,
h_scale, q_lin_results_ptr,
q_lin_results_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags); // Input Linear Dgrad
// Input Linear Dgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, output_lin_dim,
output_lin_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()),
static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(input_lin_output_grads.data_ptr()),
static_cast<const void*>(input_lin_output_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear Wgrad
// Input Linear Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, output_lin_dim,
output_lin_dim, batches,
batches, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_grads.data_ptr()),
static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_weight_grads.data_ptr()),
static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
return {input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads};
input_bias_grads, output_bias_grads};
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
}
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
\ No newline at end of file
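Across the hunks above, the right-hand side drops the per-call use_fp16 branching and the rocBLASStatusToHIPStatus / hipOperationToRocOperation wrappers, calling rocblas_gemm_ex directly with float alpha/beta; in most of the gemm_ex calls the compute type also moves to rocblas_datatype_f32_r while the operands stay fp16. A minimal sketch of that call pattern, kept separate from the diff itself (the helper name and header path are ours, not this repository's; the rocblas_gemm_ex argument order is the one used in the hunks):

#include <rocblas.h>   // newer ROCm installs this as <rocblas/rocblas.h>

// D = alpha * op(A) * op(B) + beta * C, with fp16 A/B/C/D storage and fp32
// accumulation. D aliases C, exactly as in the calls above.
static rocblas_status gemm_fp16_fp32acc(rocblas_handle handle,
                                        rocblas_operation trans_a,
                                        rocblas_operation trans_b,
                                        rocblas_int m, rocblas_int n, rocblas_int k,
                                        float alpha, const void* a, rocblas_int lda,
                                        const void* b, rocblas_int ldb,
                                        float beta, void* c, rocblas_int ldc)
{
  return rocblas_gemm_ex(handle, trans_a, trans_b, m, n, k,
                         &alpha,
                         a, rocblas_datatype_f16_r, lda,
                         b, rocblas_datatype_f16_r, ldb,
                         &beta,
                         c, rocblas_datatype_f16_r, ldc,
                         c, rocblas_datatype_f16_r, ldc,   // D aliases C
                         rocblas_datatype_f32_r,           // fp32 compute/accumulate
                         rocblas_gemm_algo_standard,
                         0 /*solution_index*/, 0 /*flags*/);
}

// Hypothetical use, mirroring the "Input Linear Fwd" call above (beta = 1.0f
// because the result buffer was pre-filled with the bias via copy_):
//   gemm_fp16_fp32acc(handle, rocblas_operation_transpose, rocblas_operation_none,
//                     output_lin_dim, batches, embed_dim,
//                     1.0f, input_weights.data_ptr(), embed_dim,
//                     inputs.data_ptr(), embed_dim,
//                     1.0f, q_lin_results_ptr, output_lin_dim);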
...@@ -85,9 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
...@@ -108,7 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags))); flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
...@@ -188,9 +188,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
...@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags))); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
...@@ -289,202 +289,102 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) { TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()),
static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(output_grads.data_ptr()),
static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_lin_grads.data_ptr()),
static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(output_lin_grads.data_ptr()),
static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Output Linear Wgrad
// Output Linear Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, embed_dim,
embed_dim, batches,
batches, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(matmul2_results.data_ptr()),
static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(output_grads.data_ptr()),
static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_weight_grads.data_ptr()),
static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(output_weight_grads.data_ptr()),
static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// MatMul2 Dgrad1
// MatMul2 Dgrad1 gemm_switch_fp32accum( a_layout_t,
gemm_switch_fp32accum( a_layout_t, b_layout_n,
b_layout_n, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, head_dim,
head_dim, alpha,
h_alpha, static_cast<const half*>(v_lin_results_ptr),
static_cast<const half*>(v_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<const half*>(output_lin_grads.data_ptr()),
static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, beta,
h_beta, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, attn_batches,
attn_batches, flags);
flags);
// Matmul2 Dgrad2
// Matmul2 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, alpha,
h_alpha, static_cast<const half*>(output_lin_grads.data_ptr()),
static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, static_cast<const half*>(dropout_results.data_ptr()),
static_cast<const half*>(dropout_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, v_lin_grads_ptr,
v_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, v_lin_grads_ptr,
v_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
}
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
...@@ -504,202 +404,102 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
if (use_fp16) { gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, scale,
h_scale, k_lin_results_ptr,
k_lin_results_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
// Matmul1 Dgrad2
// Matmul1 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, scale,
h_scale, q_lin_results_ptr,
q_lin_results_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
// Input Linear Dgrad
// Input Linear Dgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, output_lin_dim,
output_lin_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()),
static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear Wgrad
// Input Linear Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, output_lin_dim,
output_lin_dim, batches,
batches, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, output_lin_dim,
output_lin_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_grads.data_ptr()),
static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, static_cast<void*>(input_weight_grads.data_ptr()),
static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r,
rocblas_datatype_f16_r, embed_dim,
embed_dim, rocblas_datatype_f32_r,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
......
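A note on the dropout handling: the file above applies the dropout mask explicitly in bwd_cuda via apex_masked_scale_cuda and then runs a separate masked-softmax backward (guarded by assert(softmax_success)), while the first two files fuse both steps into dispatch_masked_scale_softmax_backward_recompute / _stream. Both routes compute the same gradient. With inverted dropout, the forward pass is y_i = m_i x_i / (1 - p) for a saved mask m and keep probability 1 - p, so the backward pass needs only the saved mask and the same 1/(1 - p) factor that appears as 1.0/(1.0 - dropout_prob) in these calls (P below denotes the softmax probabilities; notation ours):

    dx_i = \frac{m_i}{1 - p}\, dy_i
    dS_{ij} = P_{ij}\bigl(dP_{ij} - \textstyle\sum_k dP_{ik} P_{ik}\bigr)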
...@@ -106,106 +106,54 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd
if (use_fp16) { TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), output_lin_dim,
output_lin_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()),
static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, //static_cast<const void*>(inputs.data_ptr()),
//static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, output_lin_dim,
output_lin_dim, q_lin_results_ptr,
q_lin_results_ptr, rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, output_lin_dim,
output_lin_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) gemm_switch_fp32accum( a_layout_t,
gemm_switch_fp32accum( a_layout_t, b_layout_n,
b_layout_n, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, head_dim,
head_dim, scale,
h_scale, static_cast<const half*>(k_lin_results_ptr),
static_cast<const half*>(k_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<const half*>(q_lin_results_ptr),
static_cast<const half*>(q_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, beta,
h_beta, static_cast<half*>(softmax_results_ptr),
static_cast<half*>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, static_cast<half*>(softmax_results_ptr),
static_cast<half*>(softmax_results_ptr), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, attn_batches,
attn_batches, flags);
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
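For reference, the batched GEMM tagged "MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)" computes, for each attention batch, S = scale * K^T Q, stored column-major as a k_seq_len x q_seq_len block. The loop below is a hand-written single-batch CPU illustration of that arithmetic (float data, column-major indexing); it is not code from this repository and says nothing about the kernel's tiling or batching.

// Illustration only: single-batch, column-major reference of
//   S = scale * K^T * Q,  K: head_dim x k_seq_len,  Q: head_dim x q_seq_len,
//   S: k_seq_len x q_seq_len,  scale = 1/sqrt(head_dim).
#include <cmath>
#include <vector>

void attention_scores_reference(const std::vector<float>& K, const std::vector<float>& Q,
                                std::vector<float>& S,
                                int head_dim, int k_seq_len, int q_seq_len) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  for (int qc = 0; qc < q_seq_len; ++qc) {      // column of S
    for (int kr = 0; kr < k_seq_len; ++kr) {    // row of S
      float acc = 0.0f;                         // FP32 accumulation
      for (int d = 0; d < head_dim; ++d)
        acc += K[kr * head_dim + d] * Q[qc * head_dim + d];  // K(:,kr) . Q(:,qc)
      S[qc * k_seq_len + kr] = scale * acc;
    }
  }
}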
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
...@@ -239,106 +187,54 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -239,106 +187,54 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
} }
// Matmul2 // Matmul2
if (use_fp16) { gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, alpha,
h_alpha, static_cast<const half*>(v_lin_results_ptr),
static_cast<const half*>(v_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) , //static_cast<const half*>(dropout_results.data_ptr()),
//static_cast<const half*>(dropout_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, static_cast<half*>(matmul2_results.data_ptr()),
static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, attn_batches,
attn_batches, flags);
flags);
// Output Linear
// Output Linear TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()),
static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(matmul2_results.data_ptr()),
static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_lin_results.data_ptr()),
static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(output_lin_results.data_ptr()),
static_cast<void*>(output_lin_results.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
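The "Matmul2" call contracts the attention probabilities against the value projections: per batch, Ctx = V * P, where P is the k_seq_len x q_seq_len block coming out of the softmax (dropout_results while training, softmax_results otherwise), and Ctx feeds the output linear GEMM. A single-batch, column-major CPU illustration of that product follows; names and the float element type are illustrative, not from the patch.

// Illustration only: single-batch, column-major reference of
//   Ctx = V * P,  V: head_dim x k_seq_len,  P: k_seq_len x q_seq_len,
//   Ctx: head_dim x q_seq_len.
#include <vector>

void attention_context_reference(const std::vector<float>& V, const std::vector<float>& P,
                                 std::vector<float>& Ctx,
                                 int head_dim, int k_seq_len, int q_seq_len) {
  for (int qc = 0; qc < q_seq_len; ++qc) {
    for (int d = 0; d < head_dim; ++d) {
      float acc = 0.0f;
      for (int kc = 0; kc < k_seq_len; ++kc)
        acc += V[kc * head_dim + d] * P[qc * k_seq_len + kc];  // V(d,kc) * P(kc,qc)
      Ctx[qc * head_dim + d] = acc;
    }
  }
}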
// End-of-block Dropout-Add // End-of-block Dropout-Add
...@@ -451,202 +347,102 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -451,202 +347,102 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
if (use_fp16) { TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, embed_dim,
embed_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(output_weights.data_ptr()),
static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(dropout_add_grads.data_ptr()),
static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_lin_grads.data_ptr()),
static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(output_lin_grads.data_ptr()),
static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Output Linear Wgrad
// Output Linear Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, embed_dim,
embed_dim, batches,
batches, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(matmul2_results.data_ptr()),
static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(dropout_add_grads.data_ptr()),
static_cast<const void*>(dropout_add_grads.data_ptr()), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, embed_dim,
embed_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(output_weight_grads.data_ptr()),
static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(output_weight_grads.data_ptr()),
static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// MatMul2 Dgrad1
// MatMul2 Dgrad1 gemm_switch_fp32accum( a_layout_t,
gemm_switch_fp32accum( a_layout_t, b_layout_n,
b_layout_n, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, head_dim,
head_dim, alpha,
h_alpha, static_cast<const half*>(v_lin_results_ptr),
static_cast<const half*>(v_lin_results_ptr), lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<const half*>(output_lin_grads.data_ptr()),
static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, beta,
h_beta, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, attn_batches,
attn_batches, flags);
flags);
// Matmul2 Dgrad2
// Matmul2 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, alpha,
h_alpha, static_cast<const half*>(output_lin_grads.data_ptr()),
static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches,
head_dim*attn_batches, head_dim,
head_dim, static_cast<const half*>(dropout_results.data_ptr()),
static_cast<const half*>(dropout_results.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, v_lin_grads_ptr,
v_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, v_lin_grads_ptr,
v_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
}
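The pair of rocblas_gemm_ex calls labelled "Output Linear Dgrad" and "Output Linear Wgrad" follow the usual backward pattern for a linear layer whose forward was y = W^T x in column-major form: dX = W * dY (the OP_N/OP_N call) and dW = X * dY^T (the OP_N/OP_T call, contracting over the batch dimension). The CPU loops below restate that arithmetic for one layer; they are illustrative only and not part of the patch.

// Illustration only: column-major reference for the backward of y = W^T * x.
//   W: embed_dim x embed_dim,  X, dY: embed_dim x batches.
//   dX = W * dY        (matches the OP_N / OP_N dgrad GEMM)
//   dW = X * dY^T      (matches the OP_N / OP_T wgrad GEMM, k = batches)
#include <vector>

void linear_backward_reference(const std::vector<float>& W,
                               const std::vector<float>& X,
                               const std::vector<float>& dY,
                               std::vector<float>& dX,
                               std::vector<float>& dW,
                               int embed_dim, int batches) {
  for (int b = 0; b < batches; ++b)
    for (int r = 0; r < embed_dim; ++r) {
      float acc = 0.0f;
      for (int c = 0; c < embed_dim; ++c)
        acc += W[c * embed_dim + r] * dY[b * embed_dim + c];   // row r of W . dY(:,b)
      dX[b * embed_dim + r] = acc;
    }
  for (int c = 0; c < embed_dim; ++c)       // column of dW
    for (int r = 0; r < embed_dim; ++r) {   // row of dW
      float acc = 0.0f;
      for (int b = 0; b < batches; ++b)
        acc += X[b * embed_dim + r] * dY[b * embed_dim + c];   // X(r,b) * dY(c,b)
      dW[c * embed_dim + r] = acc;
    }
}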
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
...@@ -666,206 +462,104 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -666,206 +462,104 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success); assert(softmax_success);
// Matmul1 Dgrad1 // Matmul1 Dgrad1
if (use_fp16) { gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_n,
b_layout_n, head_dim,
head_dim, q_seq_len,
q_seq_len, k_seq_len,
k_seq_len, scale,
h_scale, k_lin_results_ptr,
k_lin_results_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, q_lin_grads_ptr,
q_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
// Matmul1 Dgrad2
// Matmul1 Dgrad2 gemm_switch_fp32accum( a_layout_n,
gemm_switch_fp32accum( a_layout_n, b_layout_t,
b_layout_t, head_dim,
head_dim, k_seq_len,
k_seq_len, q_seq_len,
q_seq_len, scale,
h_scale, q_lin_results_ptr,
q_lin_results_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len,
k_seq_len, k_seq_len*q_seq_len,
k_seq_len*q_seq_len, beta,
h_beta, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, k_lin_grads_ptr,
k_lin_grads_ptr, lead_dim,
lead_dim, batch_stride,
batch_stride, attn_batches,
attn_batches, flags);
flags);
// Input Linear Dgrad
// Input Linear Dgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), embed_dim,
embed_dim, batches,
batches, output_lin_dim,
output_lin_dim, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), static_cast<const void*>(input_weights.data_ptr()),
static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
output_lin_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), //static_cast<void*>(input_grads.data_ptr()),
//static_cast<void*>(input_grads.data_ptr()), static_cast<void*>(input_lin_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(input_lin_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
// Input Linear Wgrad
// Input Linear Wgrad TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle, CUBLAS_OP_N,
hipOperationToRocOperation(CUBLAS_OP_N), CUBLAS_OP_T,
hipOperationToRocOperation(CUBLAS_OP_T), embed_dim,
embed_dim, output_lin_dim,
output_lin_dim, batches,
batches, static_cast<const void*>(&alpha),
static_cast<const void*>(&h_alpha), //static_cast<const void*>(inputs.data_ptr()),
//static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), rocblas_datatype_f16_r /*a_type*/,
rocblas_datatype_f16_r /*a_type*/, embed_dim,
embed_dim, static_cast<const void*>(q_lin_grads_ptr),
static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r /*b_type*/,
rocblas_datatype_f16_r /*b_type*/, output_lin_dim,
output_lin_dim, static_cast<const void*>(&beta),
static_cast<const void*>(&h_beta), static_cast<void*>(input_weight_grads.data_ptr()),
static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*c_type*/,
rocblas_datatype_f16_r /*c_type*/, embed_dim,
embed_dim, static_cast<void*>(input_weight_grads.data_ptr()),
static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r /*d_type*/,
rocblas_datatype_f16_r /*d_type*/, embed_dim,
embed_dim, rocblas_datatype_f32_r /*compute_type*/,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, rocblas_gemm_algo_standard /*algo*/,
rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/,
0 /*solution_index*/, flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
}
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half, float>( HostLayerNormGradient<half, float>(
...@@ -889,4 +583,4 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -889,4 +583,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex } // end namespace rocblas_gemmex
} // end namespace self_norm_add } // end namespace self_norm_add
} // end namespace multihead_attn } // end namespace multihead_attn
\ No newline at end of file
...@@ -7,8 +7,6 @@ ...@@ -7,8 +7,6 @@
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <rocblas/rocblas.h>
//#include <ATen/ATen.h> //#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
...@@ -47,52 +45,6 @@ cublasOperation_t convertTransToCublasOperation(char trans) { ...@@ -47,52 +45,6 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
} }
} }
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) { float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
...@@ -105,13 +57,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, ...@@ -105,13 +57,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float fAlpha = alpha; float fAlpha = alpha;
float fBeta = beta; float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k, opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA, (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB, b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC, (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD, d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags))); (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
} }
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
......
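RocblasStridedBatchedGemm above now calls rocblas_gemm_strided_batched_ex directly on the rocBLAS handle instead of routing the operation and status codes through the hipblas shims. The fragment below is a compile-only sketch of that API with FP16 operands and FP32 accumulation, mirroring the wrapper's argument order; the helper name and the reuse of C as the D output are assumptions for illustration, not code from this repository.

// Sketch only: strided-batched FP16 GEMM with FP32 accumulation via
// rocblas_gemm_strided_batched_ex. Batch i reads A + i*strideA and
// B + i*strideB and writes C + i*strideC in place.
#include <rocblas/rocblas.h>

static rocblas_status batched_gemm_fp16_fp32acc(
    rocblas_handle handle, rocblas_operation opA, rocblas_operation opB,
    rocblas_int m, rocblas_int n, rocblas_int k,
    const float* alpha,
    const void* A, rocblas_int lda, rocblas_stride strideA,
    const void* B, rocblas_int ldb, rocblas_stride strideB,
    const float* beta,
    void* C, rocblas_int ldc, rocblas_stride strideC,
    rocblas_int batch_count) {
  return rocblas_gemm_strided_batched_ex(
      handle, opA, opB, m, n, k,
      alpha,
      A, rocblas_datatype_f16_r, lda, strideA,
      B, rocblas_datatype_f16_r, ldb, strideB,
      beta,
      C, rocblas_datatype_f16_r, ldc, strideC,
      C, rocblas_datatype_f16_r, ldc, strideC,   // D aliases C (in-place)
      batch_count,
      rocblas_datatype_f32_r,                    // compute type
      rocblas_gemm_algo_standard, 0 /*solution_index*/, 0 /*flags*/);
}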
...@@ -10,22 +10,10 @@ ...@@ -10,22 +10,10 @@
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt // includes cublaslt
#include <cublasLt.h> #include <cublasLt.h>
#endif #endif
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_16F HIPBLAS_R_16F
// FP64 Wrapper around cublas GEMMEx // FP64 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias( cublasStatus_t gemm_bias(
cublasHandle_t handle, cublasHandle_t handle,
...@@ -42,6 +30,33 @@ cublasStatus_t gemm_bias( ...@@ -42,6 +30,33 @@ cublasStatus_t gemm_bias(
const float* beta, const float* beta,
double* C, double* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f64_r,
lda,
B,
rocblas_datatype_f64_r,
ldb,
beta,
C,
rocblas_datatype_f64_r,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -62,6 +77,7 @@ cublasStatus_t gemm_bias( ...@@ -62,6 +77,7 @@ cublasStatus_t gemm_bias(
ldc, ldc,
CUDA_R_64F, CUDA_R_64F,
CUBLAS_GEMM_DEFAULT); CUBLAS_GEMM_DEFAULT);
#endif
} }
// FP32 Wrapper around cublas GEMMEx // FP32 Wrapper around cublas GEMMEx
...@@ -80,6 +96,34 @@ cublasStatus_t gemm_bias( ...@@ -80,6 +96,34 @@ cublasStatus_t gemm_bias(
const float* beta, const float* beta,
float* C, float* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
lda,
B,
rocblas_datatype_f32_r,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -100,6 +144,7 @@ cublasStatus_t gemm_bias( ...@@ -100,6 +144,7 @@ cublasStatus_t gemm_bias(
ldc, ldc,
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT); CUBLAS_GEMM_DEFAULT);
#endif
} }
// FP16 Tensor core wrapper around cublas GEMMEx // FP16 Tensor core wrapper around cublas GEMMEx
...@@ -118,6 +163,7 @@ cublasStatus_t gemm_bias( ...@@ -118,6 +163,7 @@ cublasStatus_t gemm_bias(
const float* beta, const float* beta,
at::Half* C, at::Half* C,
int ldc) { int ldc) {
<<<<<<< HEAD
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) { if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha); half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta); half h_beta = __float2half(*beta);
...@@ -163,6 +209,56 @@ cublasStatus_t gemm_bias( ...@@ -163,6 +209,56 @@ cublasStatus_t gemm_bias(
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} }
=======
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
>>>>>>> mirror/master
} }
......
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "utils.h" #include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt // includes cublaslt
#include <cublasLt.h> #include <cublasLt.h>
...@@ -62,52 +60,6 @@ __device__ __inline__ float sigmoid(float a) { ...@@ -62,52 +60,6 @@ __device__ __inline__ float sigmoid(float a) {
return (retf); return (retf);
} }
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// FP64 Wrapper around cublas GEMMEx // FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm( cublasStatus_t mlp_gemm(
cublasHandle_t handle, cublasHandle_t handle,
...@@ -126,10 +78,10 @@ cublasStatus_t mlp_gemm( ...@@ -126,10 +78,10 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocBLASStatusToHIPStatus(rocblas_gemm_ex( return rocblas_gemm_ex(
(rocblas_handle) handle, handle,
hipOperationToRocOperation(transa), transa,
hipOperationToRocOperation(transb), transb,
m, m,
n, n,
k, k,
...@@ -150,7 +102,7 @@ cublasStatus_t mlp_gemm( ...@@ -150,7 +102,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f64_r, rocblas_datatype_f64_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag)); flag);
#else #else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
...@@ -193,10 +145,10 @@ cublasStatus_t mlp_gemm( ...@@ -193,10 +145,10 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocBLASStatusToHIPStatus(rocblas_gemm_ex( return rocblas_gemm_ex(
(rocblas_handle) handle, handle,
hipOperationToRocOperation(transa), transa,
hipOperationToRocOperation(transb), transb,
m, m,
n, n,
k, k,
...@@ -217,7 +169,7 @@ cublasStatus_t mlp_gemm( ...@@ -217,7 +169,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag)); flag);
#else #else
return cublasGemmEx( return cublasGemmEx(
...@@ -261,61 +213,31 @@ cublasStatus_t mlp_gemm( ...@@ -261,61 +213,31 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) { return rocblas_gemm_ex(
half h_alpha = __float2half(*alpha); handle,
half h_beta = __float2half(*beta); transa,
return rocBLASStatusToHIPStatus(rocblas_gemm_ex( transb,
(rocblas_handle) handle, m,
hipOperationToRocOperation(transa), n,
hipOperationToRocOperation(transb), k,
m, alpha,
n, A,
k, rocblas_datatype_f16_r,
/* alpha */ &h_alpha, lda,
A, B,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
lda, ldb,
B, beta,
rocblas_datatype_f16_r, C,
ldb, rocblas_datatype_f16_r,
/* beta */ &h_beta, ldc,
C, C,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
ldc, ldc,
C, rocblas_datatype_f32_r,
rocblas_datatype_f16_r, rocblas_gemm_algo_standard,
ldc, 0,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, flag);
rocblas_gemm_algo_standard,
0,
flag));
} else {
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag));
}
#else #else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
......