Commit 4e7a2a8e authored by flyingdown

fix up for torch2.1

parent 2a4864d5
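
The diff replaces every rocBLAS extended-GEMM call (rocblas_gemm_ex, rocblas_gemm_strided_batched_ex) with its hipBLAS counterpart (hipblasGemmEx, hipblasGemmStridedBatchedEx), presumably because the handle returned by at::cuda::getCurrentCUDABlasHandle() is a hipblasHandle_t in the ROCm build of torch 2.1. The hipBLAS entry points follow the cuBLAS shape: C is updated in place, so the separate D output matrix and the solution_index/flags arguments are dropped, and the rocblas_datatype_* / rocblas_gemm_algo_standard enums become HIPBLAS_R_* / HIPBLAS_GEMM_DEFAULT. Below is a minimal self-contained sketch (not part of the commit) of the call shape the converted code converges on, for an FP16 GEMM with FP32 accumulation; the header path, the helper name, and the hipBLAS v1 enum names are assumptions about the exact ROCm release in use:

#include <hip/hip_fp16.h>      // __half
#include <hipblas/hipblas.h>   // hipblasGemmEx (older ROCm installs ship <hipblas.h>)

// C = alpha * op(A) * op(B) + beta * C  -- FP16 storage, FP32 accumulation.
// Hypothetical helper for illustration only; the commit calls hipblasGemmEx directly.
inline hipblasStatus_t gemm_fp16_fp32acc(hipblasHandle_t handle,
                                         hipblasOperation_t op_a, hipblasOperation_t op_b,
                                         int m, int n, int k,
                                         float alpha, const __half* a, int lda,
                                         const __half* b, int ldb,
                                         float beta, __half* c, int ldc) {
  return hipblasGemmEx(handle, op_a, op_b, m, n, k,
                       static_cast<const void*>(&alpha),
                       static_cast<const void*>(a), HIPBLAS_R_16F, lda,
                       static_cast<const void*>(b), HIPBLAS_R_16F, ldb,
                       static_cast<const void*>(&beta),
                       static_cast<void*>(c), HIPBLAS_R_16F, ldc,
                       HIPBLAS_R_32F,          // compute type: accumulate in FP32
                       HIPBLAS_GEMM_DEFAULT);  // algo; no solution_index/flags arguments
}

The same shape, extended with per-matrix strides and a batch count, applies to the hipblasGemmStridedBatchedEx call in the strided-batched wrapper further down.
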
......@@ -85,12 +85,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
......@@ -98,25 +98,21 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
......@@ -124,22 +120,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -219,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -227,22 +219,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_q_results,
......@@ -318,7 +306,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
......@@ -332,7 +320,7 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -340,25 +328,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -366,22 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -493,7 +473,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -501,25 +481,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -527,25 +503,21 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -553,25 +525,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -579,22 +547,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
......
......@@ -101,7 +101,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
......@@ -116,7 +116,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
......@@ -124,26 +124,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
HIPBLAS_R_16F /*c_type*/,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
......@@ -151,22 +147,18 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
k_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
HIPBLAS_R_16F /*c_type*/,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
......@@ -246,7 +238,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -254,22 +246,18 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// End-of-block Dropout-Add
if (is_training) {
......@@ -374,7 +362,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
......@@ -396,7 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -404,25 +392,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -430,22 +414,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -557,7 +537,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -565,26 +545,22 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_q_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -592,25 +568,21 @@ std::vector<torch::Tensor> bwd_cuda(
batches_q,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -618,25 +590,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_kv_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -644,22 +612,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches_kv,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>(
......
......@@ -80,13 +80,13 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -94,22 +94,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -183,7 +179,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -191,22 +187,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, bmm1_results, dropout_results,
......@@ -267,7 +259,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
......@@ -281,7 +273,7 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -289,25 +281,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -315,22 +303,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
......@@ -441,7 +425,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -449,25 +433,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -475,22 +455,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......
......@@ -78,13 +78,13 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
input_lin_results.copy_(input_biases);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -92,22 +92,18 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_16F,
output_lin_dim,
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -189,7 +185,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -197,22 +193,18 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
......@@ -273,7 +265,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
......@@ -287,7 +279,7 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -295,25 +287,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -321,22 +309,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
......@@ -441,7 +425,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -449,25 +433,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -475,22 +455,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......
......@@ -77,12 +77,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -90,22 +90,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -185,7 +181,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -193,22 +189,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
......@@ -269,7 +261,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
......@@ -283,7 +275,7 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -291,25 +283,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -317,22 +305,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -444,7 +428,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -452,25 +436,21 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -478,22 +458,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
embed_dim,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
......
......@@ -88,7 +88,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'};
char b_layout_n{'n'};
rocblas_int flags = 0;
int flags = 0;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm
......@@ -103,7 +103,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
......@@ -111,23 +111,19 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
HIPBLAS_R_16F /*c_type*/,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -208,7 +204,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
......@@ -216,22 +212,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// End-of-block Dropout-Add
......@@ -320,7 +312,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'};
char b_layout_t{'t'};
rocblas_int flags = 0;
int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
......@@ -341,7 +333,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -349,25 +341,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -375,22 +363,18 @@ std::vector<torch::Tensor> bwd_cuda(
batches,
static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
embed_dim,
static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -502,7 +486,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
......@@ -510,26 +494,22 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim,
static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
......@@ -538,22 +518,18 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
HIPBLAS_R_16F /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
HIPBLAS_R_16F /*b_type*/,
output_lin_dim,
static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
HIPBLAS_R_16F /*c_type*/,
embed_dim,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
HIPBLAS_R_32F /*compute_type*/,
HIPBLAS_GEMM_DEFAULT /*algo*/
));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half, float>(
......
......@@ -17,15 +17,15 @@
// symbol to be automatically resolved by PyTorch libs
/*
rocblas_datatype a_type = rocblas_datatype_f16_r; // OK
rocblas_datatype b_type = rocblas_datatype_f16_r; // OK
rocblas_datatype c_type = rocblas_datatype_f16_r; // OK
rocblas_datatype d_type = rocblas_datatype_f16_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_datatype a_type = HIPBLAS_R_16F; // OK
rocblas_datatype b_type = HIPBLAS_R_16F; // OK
rocblas_datatype c_type = HIPBLAS_R_16F; // OK
rocblas_datatype d_type = HIPBLAS_R_16F;
rocblas_datatype compute_type = HIPBLAS_R_32F;
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
rocblas_gemm_algo algo = HIPBLAS_GEMM_DEFAULT;
int32_t solution_index = 0;
rocblas_int flags = 0;
int flags = 0;
*/
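
Since the same substitutions repeat throughout the commit, the mapping is summarized here in one place (comments only, not part of the diff; the right-hand names assume the hipBLAS v1 API targeted by the change):

// rocBLAS extended-API name              hipBLAS equivalent used in this commit
// --------------------------------------------------------------------------
// rocblas_gemm_ex                         hipblasGemmEx
// rocblas_gemm_strided_batched_ex         hipblasGemmStridedBatchedEx
// rocblas_datatype_f16_r                  HIPBLAS_R_16F
// rocblas_datatype_f32_r                  HIPBLAS_R_32F
// rocblas_datatype_f64_r                  HIPBLAS_R_64F
// rocblas_gemm_algo_standard              HIPBLAS_GEMM_DEFAULT
// rocblas_int flags                       int flags (kept only as a wrapper argument)
// D matrix (d, ldd, strideD),             dropped -- hipBLAS updates C in place and
//   solution_index, flags arguments       takes neither a solution index nor flags
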
namespace {
......@@ -44,38 +44,37 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, hipblasGemmAlgo_t algo, int flags) {
hipblasOperation_t opa = convertTransToCublasOperation(transa);
hipblasOperation_t opb = convertTransToCublasOperation(transb);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
hipblasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
hipblasSetStream(handle, stream);
float fAlpha = alpha;
float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
//THCublasCheck(hipblasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(hipblasGemmStridedBatchedEx(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
(void*)&fAlpha, (const void*)a, HIPBLAS_R_16F /*a_type*/, (int)lda, strideA,
(const void*)b, HIPBLAS_R_16F /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, (void*)c, HIPBLAS_R_16F /*c_type*/, (int)ldc, strideC,
(int)batchCount, HIPBLAS_R_32F /*compute_type*/, algo));
}
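
The wrapper above keeps its original argument list (including the now-unused d/ldd/strideD and flags) but forwards to hipblasGemmStridedBatchedEx. For comparison, a standalone sketch of that batched call (not part of the diff); the helper name is hypothetical, and the stride parameter type (hipblasStride vs. long long) depends on the hipBLAS release, so it is an assumption here:

#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>

// Batched C_i = alpha * op(A_i) * op(B_i) + beta * C_i, FP16 storage with FP32
// accumulation, matrices spaced at fixed strides. Illustration only.
inline hipblasStatus_t batched_gemm_fp16_fp32acc(hipblasHandle_t handle,
                                                 hipblasOperation_t op_a, hipblasOperation_t op_b,
                                                 int m, int n, int k,
                                                 float alpha,
                                                 const __half* a, int lda, long long stride_a,
                                                 const __half* b, int ldb, long long stride_b,
                                                 float beta,
                                                 __half* c, int ldc, long long stride_c,
                                                 int batch_count) {
  return hipblasGemmStridedBatchedEx(handle, op_a, op_b, m, n, k,
                                     &alpha,
                                     a, HIPBLAS_R_16F, lda, stride_a,
                                     b, HIPBLAS_R_16F, ldb, stride_b,
                                     &beta,
                                     c, HIPBLAS_R_16F, ldc, stride_c,
                                     batch_count,
                                     HIPBLAS_R_32F,          // accumulate in FP32
                                     HIPBLAS_GEMM_DEFAULT);  // algo
}
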
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
auto stream = c10::cuda::getCurrentCUDAStream();
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, int flags) {
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
} else if ( (transa == 'n') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
} else if ( (transa == 'n') && (transb == 't') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
} else {
AT_ASSERTM(false, "TransA and TransB are invalid");
}
......
......@@ -4,6 +4,7 @@ import sys
test_dirs = ["groupbn", "fused_dense", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."] # "." for test_label_smoothing.py
ROCM_BLACKLIST = [
"groupbn",
"layer_norm"
]
......
......@@ -31,7 +31,7 @@ cublasStatus_t gemm_bias(
double* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
return hipblasGemmEx(
handle,
transa,
transb,
......@@ -40,22 +40,18 @@ cublasStatus_t gemm_bias(
k,
alpha,
A,
rocblas_datatype_f64_r,
HIPBLAS_R_64F,
lda,
B,
rocblas_datatype_f64_r,
HIPBLAS_R_64F,
ldb,
beta,
C,
rocblas_datatype_f64_r,
HIPBLAS_R_64F,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
HIPBLAS_R_64F,
HIPBLAS_GEMM_DEFAULT
);
#else
return cublasGemmEx(
handle,
......@@ -97,7 +93,7 @@ cublasStatus_t gemm_bias(
float* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
return hipblasGemmEx(
handle,
transa,
transb,
......@@ -106,22 +102,18 @@ cublasStatus_t gemm_bias(
k,
alpha,
A,
rocblas_datatype_f32_r,
HIPBLAS_R_32F,
lda,
B,
rocblas_datatype_f32_r,
HIPBLAS_R_32F,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
HIPBLAS_R_32F,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT
);
#else
return cublasGemmEx(
......@@ -164,7 +156,7 @@ cublasStatus_t gemm_bias(
at::Half* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
return hipblasGemmEx(
handle,
transa,
transb,
......@@ -173,22 +165,18 @@ cublasStatus_t gemm_bias(
k,
alpha,
A,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
lda,
B,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT
);
#else
return cublasGemmEx(
handle,
......
......@@ -78,7 +78,7 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
return hipblasGemmEx(
handle,
transa,
transb,
......@@ -87,22 +87,18 @@ cublasStatus_t mlp_gemm(
k,
alpha,
A,
rocblas_datatype_f64_r,
HIPBLAS_R_64F,
lda,
B,
rocblas_datatype_f64_r,
HIPBLAS_R_64F,
ldb,
beta,
C,
rocblas_datatype_f64_r,
HIPBLAS_R_64F,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
flag);
HIPBLAS_R_64F,
HIPBLAS_GEMM_DEFAULT
);
#else
return cublasGemmEx(
handle,
......@@ -145,7 +141,7 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
return hipblasGemmEx(
handle,
transa,
transb,
......@@ -154,22 +150,18 @@ cublasStatus_t mlp_gemm(
k,
alpha,
A,
rocblas_datatype_f32_r,
HIPBLAS_R_32F,
lda,
B,
rocblas_datatype_f32_r,
HIPBLAS_R_32F,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
HIPBLAS_R_32F,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag);
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT
);
#else
return cublasGemmEx(
......@@ -213,7 +205,7 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
return hipblasGemmEx(
handle,
transa,
transb,
......@@ -222,22 +214,18 @@ cublasStatus_t mlp_gemm(
k,
alpha,
A,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
lda,
B,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
HIPBLAS_R_16F,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag);
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT
);
#else
return cublasGemmEx(
handle,
......
......@@ -8,7 +8,7 @@ import os
parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--local-rank", default=0, type=int)
args = parser.parse_args()
args.distributed = False
......
......@@ -8,7 +8,7 @@ from apex.parallel import DistributedDataParallel
parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--local-rank", default=0, type=int)
parser.add_argument("--opt_level", default="O2", type=str)
args = parser.parse_args()
......
......@@ -26,7 +26,7 @@ batch_size = 32
from apex.parallel import DistributedDataParallel as DDP
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--local-rank", default=0, type=int)
parser.add_argument("--fp16", action='store_true', default=False)
parser.add_argument("--fp64", action='store_true', default=False)
parser.add_argument("--group_size", default=0, type=int)
......
......@@ -23,7 +23,7 @@ def compare(desc, inp1, inp2, error= 1e-5):
return close
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--local-rank', type=int, default=0)
parser.add_argument('--apex', action='store_true')
args = parser.parse_args()
......
......@@ -26,7 +26,7 @@ batch_size = 32
from apex.parallel import DistributedDataParallel as DDP
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--local-rank", default=0, type=int)
parser.add_argument("--fp16", action='store_true', default=False)
parser.add_argument("--fp64", action='store_true', default=False)
args = parser.parse_args()
......