Commit 4e7a2a8e authored by flyingdown

fix up for torch2.1

parent 2a4864d5
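Every change in this commit follows the same mechanical pattern, applied to each attention kernel in the extension: the direct rocblas_gemm_ex calls are rewritten as hipblasGemmEx calls, presumably so that they line up with the hipBLAS-typed handle and status that TORCH_CUDABLAS_CHECK resolves to in the PyTorch 2.1 ROCm build. Concretely, rocblas_datatype_f16_r becomes HIPBLAS_R_16F, the accumulation type rocblas_datatype_f32_r becomes HIPBLAS_R_32F, rocblas_gemm_algo_standard becomes HIPBLAS_GEMM_DEFAULT, the rocblas_int flags variables become plain int, and the rocBLAS-only arguments are dropped: the separate D output matrix (always aliased to C in these files), solution_index, and flags, since hipblasGemmEx updates C in place and has no equivalent for the other two. The wrapper below is an illustrative sketch only, not part of the commit; the name gemm_ex_f16_f32acc is made up, and it assumes the pre-HIPBLAS_V2 hipblasGemmEx overload that these files target (operand and compute types given as hipblasDatatype_t).

#include <hipblas/hipblas.h>   // <hipblas.h> on older ROCm releases

// Illustrative wrapper showing the call shape used throughout this commit:
// fp16 (HIPBLAS_R_16F) operands with fp32 (HIPBLAS_R_32F) accumulation.
inline hipblasStatus_t gemm_ex_f16_f32acc(hipblasHandle_t handle,
                                          hipblasOperation_t transA,
                                          hipblasOperation_t transB,
                                          int m, int n, int k,
                                          const float* alpha,
                                          const void* A, int lda,
                                          const void* B, int ldb,
                                          const float* beta,
                                          void* C, int ldc) {
  return hipblasGemmEx(handle, transA, transB, m, n, k,
                       static_cast<const void*>(alpha),
                       A, HIPBLAS_R_16F, lda,
                       B, HIPBLAS_R_16F, ldb,
                       static_cast<const void*>(beta),
                       C, HIPBLAS_R_16F, ldc,    // C is read and written in place; no separate D
                       HIPBLAS_R_32F,            // compute (accumulation) type
                       HIPBLAS_GEMM_DEFAULT);    // algo; no solution_index/flags in hipBLAS
}

The CUBLAS_OP_T / CUBLAS_OP_N and TORCH_CUDABLAS_CHECK tokens are left untouched in the diff, which is consistent with the extension being run through PyTorch's hipify step at build time, so only the rocBLAS-specific names needed hand-editing.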
...@@ -85,12 +85,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -85,12 +85,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd // Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_q_dim, output_lin_q_dim,
...@@ -98,25 +98,21 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -98,25 +98,21 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(inputs_q.data_ptr()), static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_q_dim, output_lin_q_dim,
q_lin_results_ptr, HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
output_lin_q_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Fwd // Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_kv_dim, output_lin_kv_dim,
...@@ -124,22 +120,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -124,22 +120,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
k_lin_results_ptr, k_lin_results_ptr,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_kv_dim, output_lin_kv_dim,
k_lin_results_ptr, HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
output_lin_kv_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -219,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -219,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -227,22 +219,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -227,22 +219,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_q_results, return {input_lin_q_results,
...@@ -318,7 +306,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -318,7 +306,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
...@@ -332,7 +320,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -332,7 +320,7 @@ std::vector<torch::Tensor> bwd_cuda(
#endif #endif
// Output Linear Dgrad // Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -340,25 +328,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -340,25 +328,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()), HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -366,22 +350,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -366,22 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -493,7 +473,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -493,7 +473,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags); flags);
// Input Linear Q Dgrad // Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -501,25 +481,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -501,25 +481,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_q_grads.data_ptr()), static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<void*>(input_q_grads.data_ptr()), HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Q Wgrad // Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -527,25 +503,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -527,25 +503,21 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()), static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()), static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Input Linear KV Dgrad // Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -553,25 +525,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -553,25 +525,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()), static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()), HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Wgrad // Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -579,22 +547,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -579,22 +547,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv, batches_kv,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()), static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_q_grads, input_q_grads,
......
...@@ -101,7 +101,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -101,7 +101,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
...@@ -116,7 +116,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -116,7 +116,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd // Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_q_dim, output_lin_q_dim,
...@@ -124,26 +124,22 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -124,26 +124,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()), //static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
output_lin_q_dim, output_lin_q_dim,
q_lin_results_ptr, HIPBLAS_R_32F /*compute_type*/,
rocblas_datatype_f16_r /*d_type*/, HIPBLAS_GEMM_DEFAULT /*algo*/
output_lin_q_dim, ));
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Fwd // Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_kv_dim, output_lin_kv_dim,
...@@ -151,22 +147,18 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -151,22 +147,18 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
k_lin_results_ptr, k_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
output_lin_kv_dim, output_lin_kv_dim,
k_lin_results_ptr, HIPBLAS_R_32F /*compute_type*/,
rocblas_datatype_f16_r /*d_type*/, HIPBLAS_GEMM_DEFAULT /*algo*/
output_lin_kv_dim, ));
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
b_layout_n, b_layout_n,
...@@ -246,7 +238,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -246,7 +238,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -254,22 +246,18 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -254,22 +246,18 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()), static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(output_lin_results.data_ptr()), HIPBLAS_R_32F /*compute_type*/,
rocblas_datatype_f16_r /*d_type*/, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// End-of-block Dropout-Add // End-of-block Dropout-Add
if (is_training) { if (is_training) {
...@@ -374,7 +362,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -374,7 +362,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
...@@ -396,7 +384,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -396,7 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -404,25 +392,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -404,25 +392,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
rocblas_datatype_f32_r /*compute_type*/, HIPBLAS_R_32F /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -430,22 +414,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -430,22 +414,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
rocblas_datatype_f32_r /*compute_type*/, HIPBLAS_R_32F /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -557,7 +537,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -557,7 +537,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags); flags);
// Input Linear Q Dgrad // Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -565,26 +545,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -565,26 +545,22 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()), static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()), //static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()), static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
rocblas_datatype_f32_r /*compute_type*/, HIPBLAS_R_32F /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Input Linear Q Wgrad // Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -592,25 +568,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -592,25 +568,21 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q, batches_q,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()), static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
output_lin_q_dim, output_lin_q_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()), static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
rocblas_datatype_f32_r /*compute_type*/, HIPBLAS_R_32F /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Input Linear KV Dgrad // Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -618,25 +590,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -618,25 +590,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()), static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()), static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()), HIPBLAS_R_32F /*compute_type*/,
rocblas_datatype_f16_r /*d_type*/, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Wgrad // Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -644,22 +612,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -644,22 +612,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv, batches_kv,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()), static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(k_lin_grads_ptr), static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
output_lin_kv_dim, output_lin_kv_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()), static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
rocblas_datatype_f32_r /*compute_type*/, HIPBLAS_R_32F /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>( HostLayerNormGradient<half,float>(
......
...@@ -80,13 +80,13 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -80,13 +80,13 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
input_lin_results.copy_(input_biases); input_lin_results.copy_(input_biases);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -94,22 +94,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -94,22 +94,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
q_lin_results_ptr, q_lin_results_ptr,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_dim, output_lin_dim,
q_lin_results_ptr, HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
output_lin_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -183,7 +179,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -183,7 +179,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
outputs.copy_(output_biases); outputs.copy_(output_biases);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -191,22 +187,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -191,22 +187,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, bmm1_results, dropout_results, return {input_lin_results, bmm1_results, dropout_results,
...@@ -267,7 +259,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -267,7 +259,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
...@@ -281,7 +273,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -281,7 +273,7 @@ std::vector<torch::Tensor> bwd_cuda(
#endif #endif
// Output Linear Dgrad // Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -289,25 +281,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -289,25 +281,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -315,22 +303,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -315,22 +303,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()), HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1 // MatMul2 Dgrad1
...@@ -441,7 +425,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -441,7 +425,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -449,25 +433,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -449,25 +433,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()), static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()), static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Input Linear Wgrad // Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -475,22 +455,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -475,22 +455,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......
...@@ -78,13 +78,13 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, ...@@ -78,13 +78,13 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
input_lin_results.copy_(input_biases); input_lin_results.copy_(input_biases);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -92,22 +92,18 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, ...@@ -92,22 +92,18 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
q_lin_results_ptr, q_lin_results_ptr,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_dim, output_lin_dim,
q_lin_results_ptr, HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
output_lin_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -189,7 +185,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, ...@@ -189,7 +185,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
outputs.copy_(output_biases); outputs.copy_(output_biases);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -197,22 +193,18 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads, ...@@ -197,22 +193,18 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta_one), static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results, return {input_lin_results, softmax_results, dropout_results,
...@@ -273,7 +265,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -273,7 +265,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
...@@ -287,7 +279,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -287,7 +279,7 @@ std::vector<torch::Tensor> bwd_cuda(
#endif #endif
// Output Linear Dgrad // Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -295,25 +287,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -295,25 +287,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()), HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -321,22 +309,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -321,22 +309,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()), HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1 // MatMul2 Dgrad1
...@@ -441,7 +425,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -441,7 +425,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches, attn_batches,
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -449,25 +433,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -449,25 +433,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()), static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()), static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Input Linear Wgrad // Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -475,22 +455,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -475,22 +455,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......
...@@ -77,12 +77,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -77,12 +77,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -90,22 +90,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -90,22 +90,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_dim, output_lin_dim,
q_lin_results_ptr, HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
output_lin_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -185,7 +181,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -185,7 +181,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -193,22 +189,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -193,22 +189,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()), static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results, return {input_lin_results, softmax_results, dropout_results,
...@@ -269,7 +261,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -269,7 +261,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
...@@ -283,7 +275,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -283,7 +275,7 @@ std::vector<torch::Tensor> bwd_cuda(
#endif #endif
// Output Linear Dgrad // Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -291,25 +283,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -291,25 +283,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -317,22 +305,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -317,22 +305,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(output_grads.data_ptr()), static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()), HIPBLAS_R_32F,
rocblas_datatype_f16_r, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -444,7 +428,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -444,7 +428,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -452,25 +436,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -452,25 +436,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()), static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Input Linear Wgrad // Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -478,22 +458,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -478,22 +458,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()), static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r, HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim, embed_dim,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
......
...@@ -88,7 +88,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -88,7 +88,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0; int flags = 0;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
...@@ -103,7 +103,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -103,7 +103,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd // Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -111,23 +111,19 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -111,23 +111,19 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
//static_cast<const void*>(inputs.data_ptr()), //static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
output_lin_dim, output_lin_dim,
q_lin_results_ptr, HIPBLAS_R_32F /*compute_type*/,
rocblas_datatype_f16_r /*d_type*/, HIPBLAS_GEMM_DEFAULT /*algo*/
output_lin_dim, ));
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -208,7 +204,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -208,7 +204,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -216,22 +212,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -216,22 +212,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()), static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
rocblas_datatype_f32_r /*compute_type*/, HIPBLAS_R_32F /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// End-of-block Dropout-Add // End-of-block Dropout-Add
...@@ -320,7 +312,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -320,7 +312,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0; int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
...@@ -341,7 +333,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -341,7 +333,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -349,25 +341,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -349,25 +341,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
rocblas_datatype_f32_r /*compute_type*/, HIPBLAS_R_32F /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -375,22 +363,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -375,22 +363,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()), HIPBLAS_R_32F /*compute_type*/,
rocblas_datatype_f16_r /*d_type*/, HIPBLAS_GEMM_DEFAULT /*algo*/
embed_dim, ));
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -502,7 +486,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -502,7 +486,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -510,26 +494,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -510,26 +494,22 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()), //static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()), static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
rocblas_datatype_f32_r /*compute_type*/, HIPBLAS_R_32F /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Input Linear Wgrad // Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -538,22 +518,18 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -538,22 +518,18 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()), //static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/, HIPBLAS_R_16F /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/, HIPBLAS_R_16F /*b_type*/,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/, HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
rocblas_datatype_f32_r /*compute_type*/, HIPBLAS_R_32F /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, HIPBLAS_GEMM_DEFAULT /*algo*/
0 /*solution_index*/, ));
flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half, float>( HostLayerNormGradient<half, float>(
......
...@@ -17,15 +17,15 @@ ...@@ -17,15 +17,15 @@
// symbol to be automatically resolved by PyTorch libs // symbol to be automatically resolved by PyTorch libs
/* /*
rocblas_datatype a_type = rocblas_datatype_f16_r; // OK hipblasDatatype_t a_type = HIPBLAS_R_16F; // OK
rocblas_datatype b_type = rocblas_datatype_f16_r; // OK hipblasDatatype_t b_type = HIPBLAS_R_16F; // OK
rocblas_datatype c_type = rocblas_datatype_f16_r; // OK hipblasDatatype_t c_type = HIPBLAS_R_16F; // OK
rocblas_datatype d_type = rocblas_datatype_f16_r; hipblasDatatype_t d_type = HIPBLAS_R_16F;
rocblas_datatype compute_type = rocblas_datatype_f32_r; hipblasDatatype_t compute_type = HIPBLAS_R_32F;
rocblas_gemm_algo algo = rocblas_gemm_algo_standard; hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
int32_t solution_index = 0; int32_t solution_index = 0;
rocblas_int flags = 0; int flags = 0;
*/ */
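For reference, a minimal sketch (not part of this commit) of the hipblasGemmEx call shape that the mapping above corresponds to: half-precision A/B/C with float accumulation and the default algorithm, where the rocblas_gemm_ex algo/solution_index/flags triple collapses into the single HIPBLAS_GEMM_DEFAULT argument. The helper name and include paths below are assumptions for illustration only.

// Minimal sketch, assuming hipBLAS headers from a ROCm toolchain; not taken from this commit.
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>

// Computes C = alpha*A*B + beta*C with fp16 storage and fp32 accumulation.
static hipblasStatus_t gemm_fp16_fp32acc(hipblasHandle_t handle,
                                         int m, int n, int k,
                                         const __half* d_A, int lda,
                                         const __half* d_B, int ldb,
                                         __half* d_C, int ldc) {
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  return hipblasGemmEx(handle,
                       HIPBLAS_OP_N, HIPBLAS_OP_N,
                       m, n, k,
                       &alpha,
                       d_A, HIPBLAS_R_16F, lda,   // a_type
                       d_B, HIPBLAS_R_16F, ldb,   // b_type
                       &beta,
                       d_C, HIPBLAS_R_16F, ldc,   // c_type (C is also the output)
                       HIPBLAS_R_32F,             // compute_type
                       HIPBLAS_GEMM_DEFAULT);     // algo
}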
namespace { namespace {
...@@ -44,38 +44,37 @@ cublasOperation_t convertTransToCublasOperation(char trans) { ...@@ -44,38 +44,37 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) { float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, hipblasGemmAlgo_t algo, int flags) {
cublasOperation_t opa = convertTransToCublasOperation(transa); hipblasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb); hipblasOperation_t opb = convertTransToCublasOperation(transb);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); hipblasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
cublasSetStream(handle, stream); hipblasSetStream(handle, stream);
float fAlpha = alpha; float fAlpha = alpha;
float fBeta = beta; float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(hipblasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, TORCH_CUDABLAS_CHECK(hipblasGemmStridedBatchedEx(handle,
opa, opb, (int)m, (int)n, (int)k, opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA, (void*)&fAlpha, (const void*)a, HIPBLAS_R_16F /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB, (const void*)b, HIPBLAS_R_16F /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC, (void*)&fBeta, (void*)c, HIPBLAS_R_16F /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD, (int)batchCount, HIPBLAS_R_32F /*compute_type*/, algo));
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
} }
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) { float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, int flags) {
auto stream = c10::cuda::getCurrentCUDAStream(); auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if ( (transa == 't') && (transb == 'n') ) { if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
} else if ( (transa == 'n') && (transb == 'n') ) { } else if ( (transa == 'n') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
} else if ( (transa == 'n') && (transb == 't') ) { } else if ( (transa == 'n') && (transb == 't') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
} else { } else {
AT_ASSERTM(false, "TransA and TransB are invalid"); AT_ASSERTM(false, "TransA and TransB are invalid");
} }
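Since every branch above now selects HIPBLAS_GEMM_DEFAULT, the leading-dimension alignment checks no longer influence the algorithm choice. Below is a hedged sketch of an equivalent, collapsed dispatch; it is illustrative only, reuses RocblasStridedBatchedGemm and AT_ASSERTM from this file, and is not part of the commit.

// Illustrative simplification: one validity check, one call; behavior matches the
// dispatch above because all branches pass HIPBLAS_GEMM_DEFAULT.
void gemm_switch_fp32accum_simplified(char transa, char transb, long m, long n, long k,
    float alpha, const half *a, long lda, long strideA,
    const half *b, long ldb, long strideB,
    float beta, half *c, long ldc, long strideC,
    half *d, long ldd, long strideD, long batchCount, int flags) {
  const bool valid = (transa == 't' && transb == 'n') ||
                     (transa == 'n' && transb == 'n') ||
                     (transa == 'n' && transb == 't');
  AT_ASSERTM(valid, "TransA and TransB are invalid");
  RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                            b, ldb, strideB, beta, c, ldc, strideC,
                            d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags);
}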
......
...@@ -4,6 +4,7 @@ import sys ...@@ -4,6 +4,7 @@ import sys
test_dirs = ["groupbn", "fused_dense", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."] # "." for test_label_smoothing.py test_dirs = ["groupbn", "fused_dense", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."] # "." for test_label_smoothing.py
ROCM_BLACKLIST = [ ROCM_BLACKLIST = [
"groupbn",
"layer_norm" "layer_norm"
] ]
......
...@@ -31,7 +31,7 @@ cublasStatus_t gemm_bias( ...@@ -31,7 +31,7 @@ cublasStatus_t gemm_bias(
double* C, double* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex( return hipblasGemmEx(
handle, handle,
transa, transa,
transb, transb,
...@@ -40,22 +40,18 @@ cublasStatus_t gemm_bias( ...@@ -40,22 +40,18 @@ cublasStatus_t gemm_bias(
k, k,
alpha, alpha,
A, A,
rocblas_datatype_f64_r, HIPBLAS_R_64F,
lda, lda,
B, B,
rocblas_datatype_f64_r, HIPBLAS_R_64F,
ldb, ldb,
beta, beta,
C, C,
rocblas_datatype_f64_r, HIPBLAS_R_64F,
ldc, ldc,
C, HIPBLAS_R_64F,
rocblas_datatype_f64_r, HIPBLAS_GEMM_DEFAULT
ldc, );
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
#else #else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
...@@ -97,7 +93,7 @@ cublasStatus_t gemm_bias( ...@@ -97,7 +93,7 @@ cublasStatus_t gemm_bias(
float* C, float* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex( return hipblasGemmEx(
handle, handle,
transa, transa,
transb, transb,
...@@ -106,22 +102,18 @@ cublasStatus_t gemm_bias( ...@@ -106,22 +102,18 @@ cublasStatus_t gemm_bias(
k, k,
alpha, alpha,
A, A,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
lda, lda,
B, B,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
ldb, ldb,
beta, beta,
C, C,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
ldc,
C,
rocblas_datatype_f32_r,
ldc, ldc,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard, HIPBLAS_GEMM_DEFAULT
0, );
0);
#else #else
return cublasGemmEx( return cublasGemmEx(
...@@ -164,7 +156,7 @@ cublasStatus_t gemm_bias( ...@@ -164,7 +156,7 @@ cublasStatus_t gemm_bias(
at::Half* C, at::Half* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex( return hipblasGemmEx(
handle, handle,
transa, transa,
transb, transb,
...@@ -173,22 +165,18 @@ cublasStatus_t gemm_bias( ...@@ -173,22 +165,18 @@ cublasStatus_t gemm_bias(
k, k,
alpha, alpha,
A, A,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
lda, lda,
B, B,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
ldb, ldb,
beta, beta,
C, C,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
ldc,
C,
rocblas_datatype_f16_r,
ldc, ldc,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard, HIPBLAS_GEMM_DEFAULT
0, );
0);
#else #else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
......
...@@ -78,7 +78,7 @@ cublasStatus_t mlp_gemm( ...@@ -78,7 +78,7 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex( return hipblasGemmEx(
handle, handle,
transa, transa,
transb, transb,
...@@ -87,22 +87,18 @@ cublasStatus_t mlp_gemm( ...@@ -87,22 +87,18 @@ cublasStatus_t mlp_gemm(
k, k,
alpha, alpha,
A, A,
rocblas_datatype_f64_r, HIPBLAS_R_64F,
lda, lda,
B, B,
rocblas_datatype_f64_r, HIPBLAS_R_64F,
ldb, ldb,
beta, beta,
C, C,
rocblas_datatype_f64_r, HIPBLAS_R_64F,
ldc, ldc,
C, HIPBLAS_R_64F,
rocblas_datatype_f64_r, HIPBLAS_GEMM_DEFAULT
ldc, );
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
flag);
#else #else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
...@@ -145,7 +141,7 @@ cublasStatus_t mlp_gemm( ...@@ -145,7 +141,7 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex( return hipblasGemmEx(
handle, handle,
transa, transa,
transb, transb,
...@@ -154,22 +150,18 @@ cublasStatus_t mlp_gemm( ...@@ -154,22 +150,18 @@ cublasStatus_t mlp_gemm(
k, k,
alpha, alpha,
A, A,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
lda, lda,
B, B,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
ldb, ldb,
beta, beta,
C, C,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
ldc,
C,
rocblas_datatype_f32_r,
ldc, ldc,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard, HIPBLAS_GEMM_DEFAULT
0, );
flag);
#else #else
return cublasGemmEx( return cublasGemmEx(
...@@ -213,7 +205,7 @@ cublasStatus_t mlp_gemm( ...@@ -213,7 +205,7 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex( return hipblasGemmEx(
handle, handle,
transa, transa,
transb, transb,
...@@ -222,22 +214,18 @@ cublasStatus_t mlp_gemm( ...@@ -222,22 +214,18 @@ cublasStatus_t mlp_gemm(
k, k,
alpha, alpha,
A, A,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
lda, lda,
B, B,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
ldb, ldb,
beta, beta,
C, C,
rocblas_datatype_f16_r, HIPBLAS_R_16F,
ldc,
C,
rocblas_datatype_f16_r,
ldc, ldc,
rocblas_datatype_f32_r, HIPBLAS_R_32F,
rocblas_gemm_algo_standard, HIPBLAS_GEMM_DEFAULT
0, );
flag);
#else #else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
......
...@@ -8,7 +8,7 @@ import os ...@@ -8,7 +8,7 @@ import os
parser = argparse.ArgumentParser(description='allreduce hook example') parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local-rank", default=0, type=int)
args = parser.parse_args() args = parser.parse_args()
args.distributed = False args.distributed = False
......
...@@ -8,7 +8,7 @@ from apex.parallel import DistributedDataParallel ...@@ -8,7 +8,7 @@ from apex.parallel import DistributedDataParallel
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch. # automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local-rank", default=0, type=int)
parser.add_argument("--opt_level", default="O2", type=str) parser.add_argument("--opt_level", default="O2", type=str)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -26,7 +26,7 @@ batch_size = 32 ...@@ -26,7 +26,7 @@ batch_size = 32
from apex.parallel import DistributedDataParallel as DDP from apex.parallel import DistributedDataParallel as DDP
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local-rank", default=0, type=int)
parser.add_argument("--fp16", action='store_true', default=False) parser.add_argument("--fp16", action='store_true', default=False)
parser.add_argument("--fp64", action='store_true', default=False) parser.add_argument("--fp64", action='store_true', default=False)
parser.add_argument("--group_size", default=0, type=int) parser.add_argument("--group_size", default=0, type=int)
......
...@@ -23,7 +23,7 @@ def compare(desc, inp1, inp2, error= 1e-5): ...@@ -23,7 +23,7 @@ def compare(desc, inp1, inp2, error= 1e-5):
return close return close
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local-rank', type=int, default=0)
parser.add_argument('--apex', action='store_true') parser.add_argument('--apex', action='store_true')
args = parser.parse_args() args = parser.parse_args()
......
...@@ -26,7 +26,7 @@ batch_size = 32 ...@@ -26,7 +26,7 @@ batch_size = 32
from apex.parallel import DistributedDataParallel as DDP from apex.parallel import DistributedDataParallel as DDP
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local-rank", default=0, type=int)
parser.add_argument("--fp16", action='store_true', default=False) parser.add_argument("--fp16", action='store_true', default=False)
parser.add_argument("--fp64", action='store_true', default=False) parser.add_argument("--fp64", action='store_true', default=False)
args = parser.parse_args() args = parser.parse_args()
......