Unverified commit 3ba7192d authored by Peng, committed by GitHub

Merge pull request #116 from ROCmSoftwarePlatform/revert_hipblas

Revert "Changes to support hipblas migration (#113)"
parents 8fc9b21f e4d21865
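For context on what is being reverted: #113 moved these extensions onto a hipBLAS-style handle, so every rocBLAS GEMM call had to cast the handle back to rocblas_handle, convert the operation enums with hipOperationToRocOperation(), and translate the returned rocblas_status through rocBLASStatusToHIPStatus() before passing it to TORCH_CUDABLAS_CHECK. This revert restores the earlier pattern, where the handle and the CUBLAS_OP_* names map directly onto rocBLAS types under the older hipify rules and the rocblas_status is checked as-is. Below is a minimal before/after sketch of one call site; the GEMM shape and pointer arguments (m, n, k, alpha, beta, A, B, C, lda, ldb, ldc, flags) are placeholders for illustration, not the exact arguments used in the hunks that follow.

// Sketch only: argument values are placeholders; see the diff hunks below for the real calls.

// Pattern removed by this revert (introduced in #113): cast the handle, convert the
// operation enums, and convert the rocblas_status result back to a hipblasStatus_t.
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
    (rocblas_handle) handle,
    hipOperationToRocOperation(CUBLAS_OP_T), hipOperationToRocOperation(CUBLAS_OP_N),
    m, n, k,
    &alpha,
    A, rocblas_datatype_f16_r, lda,
    B, rocblas_datatype_f16_r, ldb,
    &beta,
    C, rocblas_datatype_f16_r, ldc,
    C, rocblas_datatype_f16_r, ldc,          // D output aliases C, as in the calls below
    rocblas_datatype_f32_r /*compute_type*/,
    rocblas_gemm_algo_standard, 0 /*solution_index*/, flags)));

// Pattern restored by this revert: the handle and CUBLAS_OP_* enums already translate
// to rocBLAS types, so rocblas_gemm_ex is called directly and its status checked as-is.
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
    handle,
    CUBLAS_OP_T, CUBLAS_OP_N,
    m, n, k,
    &alpha,
    A, rocblas_datatype_f16_r, lda,
    B, rocblas_datatype_f16_r, ldb,
    &beta,
    C, rocblas_datatype_f16_r, ldc,
    C, rocblas_datatype_f16_r, ldc,
    rocblas_datatype_f32_r /*compute_type*/,
    rocblas_gemm_algo_standard, 0 /*solution_index*/, flags));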
@@ -90,9 +90,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
 // Input Linear Q Fwd
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 output_lin_q_dim,
 batches_q,
 embed_dim,
@@ -113,12 +113,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear KV Fwd
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 output_lin_kv_dim,
 batches_kv,
 embed_dim,
@@ -139,7 +139,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -219,9 +219,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 flags);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 embed_dim,
 batches_q,
 embed_dim,
@@ -242,7 +242,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 return {input_lin_q_results,
@@ -332,9 +332,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches_q,
 embed_dim,
@@ -355,12 +355,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 embed_dim,
 batches_q,
@@ -381,7 +381,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul2 Dgrad1
 gemm_switch_fp32accum( a_layout_t,
@@ -493,9 +493,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Q Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches_q,
 output_lin_q_dim,
@@ -516,12 +516,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear Q Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 output_lin_q_dim,
 batches_q,
@@ -542,12 +542,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear KV Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches_kv,
 output_lin_kv_dim,
@@ -568,12 +568,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear KV Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 output_lin_kv_dim,
 batches_kv,
@@ -594,7 +594,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 return {
 input_q_grads,
...
@@ -116,9 +116,9 @@ std::vector<torch::Tensor> fwd_cuda(
 static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
 // Input Linear Q Fwd
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 output_lin_q_dim,
 batches_q,
 embed_dim,
@@ -140,12 +140,12 @@ std::vector<torch::Tensor> fwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear KV Fwd
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 output_lin_kv_dim,
 batches_kv,
 embed_dim,
@@ -166,7 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
@@ -246,9 +246,9 @@ std::vector<torch::Tensor> fwd_cuda(
 flags);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 embed_dim,
 batches_q,
 embed_dim,
@@ -269,7 +269,7 @@ std::vector<torch::Tensor> fwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // End-of-block Dropout-Add
 if (is_training) {
@@ -396,9 +396,9 @@ std::vector<torch::Tensor> bwd_cuda(
 (1.0 / (1.0 - dropout_prob)));
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches_q,
 embed_dim,
@@ -419,12 +419,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 embed_dim,
 batches_q,
@@ -445,7 +445,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul2 Dgrad1
 gemm_switch_fp32accum( a_layout_t,
@@ -557,9 +557,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Q Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches_q,
 output_lin_q_dim,
@@ -581,12 +581,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear Q Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 output_lin_q_dim,
 batches_q,
@@ -607,12 +607,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear KV Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches_kv,
 output_lin_kv_dim,
@@ -633,12 +633,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear KV Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 output_lin_kv_dim,
 batches_kv,
@@ -659,7 +659,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Fused Layer Norm Bwd with Residual Add
 HostLayerNormGradient<half,float>(
...
@@ -86,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 // Input Linear Fwd
 input_lin_results.copy_(input_biases);
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 output_lin_dim,
 batches,
 embed_dim,
@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -183,9 +183,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 outputs.copy_(output_biases);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 embed_dim,
@@ -206,7 +206,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 return {input_lin_results, bmm1_results, dropout_results,
@@ -281,9 +281,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 embed_dim,
@@ -304,12 +304,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 embed_dim,
 batches,
@@ -330,7 +330,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
 // MatMul2 Dgrad1
@@ -441,9 +441,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 output_lin_dim,
@@ -464,12 +464,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 output_lin_dim,
 batches,
@@ -490,7 +490,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
...
@@ -84,9 +84,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 // Input Linear Fwd
 input_lin_results.copy_(input_biases);
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 output_lin_dim,
 batches,
 embed_dim,
@@ -107,7 +107,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -189,9 +189,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 outputs.copy_(output_biases);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 embed_dim,
@@ -212,7 +212,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 return {input_lin_results, softmax_results, dropout_results,
@@ -287,9 +287,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 embed_dim,
@@ -310,12 +310,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 embed_dim,
 batches,
@@ -336,7 +336,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
 // MatMul2 Dgrad1
@@ -441,9 +441,9 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches,
 flags);
 // Input Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 output_lin_dim,
@@ -464,12 +464,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 output_lin_dim,
 batches,
@@ -490,7 +490,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
...
@@ -82,9 +82,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
 // Input Linear Fwd
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 output_lin_dim,
 batches,
 embed_dim,
@@ -105,7 +105,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -185,9 +185,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 flags);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 embed_dim,
@@ -208,7 +208,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 return {input_lin_results, softmax_results, dropout_results,
@@ -283,9 +283,9 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 embed_dim,
@@ -306,12 +306,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 embed_dim,
 batches,
@@ -332,7 +332,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul2 Dgrad1
 gemm_switch_fp32accum( a_layout_t,
@@ -444,9 +444,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 output_lin_dim,
@@ -467,12 +467,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 output_lin_dim,
 batches,
@@ -493,7 +493,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 return {
...
@@ -103,9 +103,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
 // Input Linear Fwd
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 output_lin_dim,
 batches,
 embed_dim,
@@ -127,7 +127,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -208,9 +208,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 flags);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_T),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_T,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 embed_dim,
@@ -231,7 +231,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // End-of-block Dropout-Add
@@ -341,9 +341,9 @@ std::vector<torch::Tensor> bwd_cuda(
 (1.0 / (1.0 - dropout_prob)));
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 embed_dim,
@@ -364,12 +364,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 embed_dim,
 batches,
@@ -390,7 +390,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // MatMul2 Dgrad1
 gemm_switch_fp32accum( a_layout_t,
@@ -502,9 +502,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_N),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_N,
 embed_dim,
 batches,
 output_lin_dim,
@@ -526,12 +526,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Input Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
-hipOperationToRocOperation(CUBLAS_OP_N),
-hipOperationToRocOperation(CUBLAS_OP_T),
+TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+CUBLAS_OP_N,
+CUBLAS_OP_T,
 embed_dim,
 output_lin_dim,
 batches,
@@ -553,7 +553,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags)));
+flags));
 // Fused Layer Norm Bwd with Residual Add
 HostLayerNormGradient<half, float>(
...
@@ -7,8 +7,6 @@
 //#include <cuda_profiler_api.h>
 #include <cuda_runtime.h>
-#include <rocblas/rocblas.h>
 //#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
@@ -44,52 +42,6 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
 }
 }
-// needed to work around calling rocblas API instead of hipblas API
-static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
-{
-switch(op)
-{
-case HIPBLAS_OP_N:
-return rocblas_operation_none;
-case HIPBLAS_OP_T:
-return rocblas_operation_transpose;
-case HIPBLAS_OP_C:
-return rocblas_operation_conjugate_transpose;
-}
-AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
-static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
-{
-switch(error)
-{
-case rocblas_status_size_unchanged:
-case rocblas_status_size_increased:
-case rocblas_status_success:
-case rocblas_status_continue:
-return HIPBLAS_STATUS_SUCCESS;
-case rocblas_status_invalid_handle:
-return HIPBLAS_STATUS_NOT_INITIALIZED;
-case rocblas_status_not_implemented:
-case rocblas_status_excluded_from_build:
-return HIPBLAS_STATUS_NOT_SUPPORTED;
-case rocblas_status_invalid_pointer:
-case rocblas_status_invalid_size:
-case rocblas_status_invalid_value:
-case rocblas_status_size_query_mismatch:
-return HIPBLAS_STATUS_INVALID_VALUE;
-case rocblas_status_memory_error:
-return HIPBLAS_STATUS_ALLOC_FAILED;
-case rocblas_status_internal_error:
-case rocblas_status_perf_degraded:
-case rocblas_status_check_numerics_fail:
-return HIPBLAS_STATUS_INTERNAL_ERROR;
-case rocblas_status_arch_mismatch:
-return HIPBLAS_STATUS_ARCH_MISMATCH;
-}
-AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
 void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
 float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
 float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
@@ -102,13 +54,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
 float fAlpha = alpha;
 float fBeta = beta;
 //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
-hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
+TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
+opa, opb, (int)m, (int)n, (int)k,
 (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
 b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
 (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
 d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
-(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)));
+(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
 }
 void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
...
@@ -10,21 +10,10 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
-#include <rocblas/rocblas.h>
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
 // includes cublaslt
 #include <cublasLt.h>
 #endif
-// until we use hipblas v2
-// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
-// however hipblas v1 is still using its custom type
-#define HIP_R_64F HIPBLAS_R_64F
-#define HIP_R_32F HIPBLAS_R_32F
-#define HIP_R_16F HIPBLAS_R_16F
 // FP64 Wrapper around cublas GEMMEx
 cublasStatus_t gemm_bias(
 cublasHandle_t handle,
@@ -41,6 +30,33 @@ cublasStatus_t gemm_bias(
 const float* beta,
 double* C,
 int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+return rocblas_gemm_ex(
+handle,
+transa,
+transb,
+m,
+n,
+k,
+alpha,
+A,
+rocblas_datatype_f64_r,
+lda,
+B,
+rocblas_datatype_f64_r,
+ldb,
+beta,
+C,
+rocblas_datatype_f64_r,
+ldc,
+C,
+rocblas_datatype_f64_r,
+ldc,
+rocblas_datatype_f64_r,
+rocblas_gemm_algo_standard,
+0,
+0);
+#else
 return cublasGemmEx(
 handle,
 transa,
@@ -61,6 +77,7 @@ cublasStatus_t gemm_bias(
 ldc,
 CUDA_R_64F,
 CUBLAS_GEMM_DEFAULT);
+#endif
 }
 // FP32 Wrapper around cublas GEMMEx
@@ -79,6 +96,34 @@ cublasStatus_t gemm_bias(
 const float* beta,
 float* C,
 int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+return rocblas_gemm_ex(
+handle,
+transa,
+transb,
+m,
+n,
+k,
+alpha,
+A,
+rocblas_datatype_f32_r,
+lda,
+B,
+rocblas_datatype_f32_r,
+ldb,
+beta,
+C,
+rocblas_datatype_f32_r,
+ldc,
+C,
+rocblas_datatype_f32_r,
+ldc,
+rocblas_datatype_f32_r,
+rocblas_gemm_algo_standard,
+0,
+0);
+#else
 return cublasGemmEx(
 handle,
 transa,
@@ -99,6 +144,7 @@ cublasStatus_t gemm_bias(
 ldc,
 CUDA_R_32F,
 CUBLAS_GEMM_DEFAULT);
+#endif
 }
 // FP16 Tensor core wrapper around cublas GEMMEx
@@ -117,6 +163,33 @@ cublasStatus_t gemm_bias(
 const float* beta,
 at::Half* C,
 int ldc) {
+#ifdef __HIP_PLATFORM_HCC__
+return rocblas_gemm_ex(
+handle,
+transa,
+transb,
+m,
+n,
+k,
+alpha,
+A,
+rocblas_datatype_f16_r,
+lda,
+B,
+rocblas_datatype_f16_r,
+ldb,
+beta,
+C,
+rocblas_datatype_f16_r,
+ldc,
+C,
+rocblas_datatype_f16_r,
+ldc,
+rocblas_datatype_f32_r,
+rocblas_gemm_algo_standard,
+0,
+0);
+#else
 return cublasGemmEx(
 handle,
 transa,
@@ -137,6 +210,7 @@ cublasStatus_t gemm_bias(
 ldc,
 CUDA_R_32F,
 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
 }
...
@@ -12,8 +12,6 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
-#include <rocblas/rocblas.h>
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
 // includes cublaslt
 #include <cublasLt.h>
@@ -60,52 +58,6 @@ __device__ __inline__ float sigmoid(float a) {
 return (retf);
 }
-// needed to work around calling rocblas API instead of hipblas API
-static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
-{
-switch(op)
-{
-case HIPBLAS_OP_N:
-return rocblas_operation_none;
-case HIPBLAS_OP_T:
-return rocblas_operation_transpose;
-case HIPBLAS_OP_C:
-return rocblas_operation_conjugate_transpose;
-}
-AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
-static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
-{
-switch(error)
-{
-case rocblas_status_size_unchanged:
-case rocblas_status_size_increased:
-case rocblas_status_success:
-case rocblas_status_continue:
-return HIPBLAS_STATUS_SUCCESS;
-case rocblas_status_invalid_handle:
-return HIPBLAS_STATUS_NOT_INITIALIZED;
-case rocblas_status_not_implemented:
-case rocblas_status_excluded_from_build:
-return HIPBLAS_STATUS_NOT_SUPPORTED;
-case rocblas_status_invalid_pointer:
-case rocblas_status_invalid_size:
-case rocblas_status_invalid_value:
-case rocblas_status_size_query_mismatch:
-return HIPBLAS_STATUS_INVALID_VALUE;
-case rocblas_status_memory_error:
-return HIPBLAS_STATUS_ALLOC_FAILED;
-case rocblas_status_internal_error:
-case rocblas_status_perf_degraded:
-case rocblas_status_check_numerics_fail:
-return HIPBLAS_STATUS_INTERNAL_ERROR;
-case rocblas_status_arch_mismatch:
-return HIPBLAS_STATUS_ARCH_MISMATCH;
-}
-AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
-}
 // FP64 Wrapper around cublas GEMMEx
 cublasStatus_t mlp_gemm(
 cublasHandle_t handle,
@@ -124,10 +76,10 @@ cublasStatus_t mlp_gemm(
 int ldc,
 int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
-(rocblas_handle) handle,
-hipOperationToRocOperation(transa),
-hipOperationToRocOperation(transb),
+return rocblas_gemm_ex(
+handle,
+transa,
+transb,
 m,
 n,
 k,
@@ -148,7 +100,7 @@ cublasStatus_t mlp_gemm(
 rocblas_datatype_f64_r,
 rocblas_gemm_algo_standard,
 0,
-flag));
+flag);
 #else
 return cublasGemmEx(
 handle,
@@ -191,10 +143,10 @@ cublasStatus_t mlp_gemm(
 int ldc,
 int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
-(rocblas_handle) handle,
-hipOperationToRocOperation(transa),
-hipOperationToRocOperation(transb),
+return rocblas_gemm_ex(
+handle,
+transa,
+transb,
 m,
 n,
 k,
@@ -215,7 +167,7 @@ cublasStatus_t mlp_gemm(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard,
 0,
-flag));
+flag);
 #else
 return cublasGemmEx(
@@ -259,10 +211,10 @@ cublasStatus_t mlp_gemm(
 int ldc,
 int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
-(rocblas_handle) handle,
-hipOperationToRocOperation(transa),
-hipOperationToRocOperation(transb),
+return rocblas_gemm_ex(
+handle,
+transa,
+transb,
 m,
 n,
 k,
@@ -283,7 +235,7 @@ cublasStatus_t mlp_gemm(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard,
 0,
-flag));
+flag);
 #else
 return cublasGemmEx(
 handle,
...