Unverified commit b6a1f48b authored by athitten, committed by GitHub

Add rocblas_alt_impl flag for bwd rocblas calls in MHA (#70)



* Add missing flags arg in gemm_switch_fp32accum call

* Add rocblas_alt_impl flag in MHA

* Add rocblas_alt_impl flag for all bwd GEMMs in MHA module

* Use an ifdef for rocblas_gemm_flags_fp16_alt_impl to target various AMD hardware (see the sketch below)
Co-authored-by: hubertlu-tw <hubertlu@amd.com>
parent 7bef81f7
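
For orientation, the pattern this commit applies at every backward GEMM call site looks roughly like the sketch below: compute flags once under the USE_GEMM_FLAGS_FP16_ALT_IMPL guard, then thread it through the gemm_switch_fp32accum wrapper, whose extended signature appears in the last file of this diff. This is a minimal sketch, not repository code: the function name example_bwd_gemm, the include paths, and the leading-dimension/stride values are illustrative assumptions.

// Minimal sketch (assumed context: ROCm build of apex/PyTorch; header paths may differ by ROCm version).
#include <ATen/ATen.h>           // assumed location of at::BackwardPassGuard in the ROCm fork
#include <hip/hip_fp16.h>        // provides the half type
#include <rocblas/rocblas.h>     // rocblas_int, rocblas_gemm_flags_fp16_alt_impl

// Declared in the repository's strided-batched GEMM header; signature taken from the last hunk of this diff.
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
    float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
    float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD,
    long batchCount, rocblas_int flags);

// Hypothetical call site showing how the flag is selected and passed through.
void example_bwd_gemm(const half *a, const half *b, half *c,
                      long m, long n, long k, long batch) {
  rocblas_int flags = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  // Request the fp16 alternate implementation only while autograd runs the backward pass.
  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  // Column-major, non-transposed batched GEMM: C = A * B per batch; strides below are placeholders.
  gemm_switch_fp32accum('n', 'n', m, n, k,
                        /*alpha=*/1.0f, a, /*lda=*/m, /*strideA=*/m * k,
                        b, /*ldb=*/k, /*strideB=*/k * n,
                        /*beta=*/0.0f, c, /*ldc=*/m, /*strideC=*/m * n,
                        /*d=*/c, /*ldd=*/m, /*strideD=*/m * n,
                        batch, flags);
}
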
@@ -87,6 +87,10 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
@@ -159,7 +163,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Padded Softmax
bool softmax_success = false;
@@ -212,7 +217,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
-attn_batches);
+attn_batches,
+flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -315,7 +321,9 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
@@ -388,7 +396,8 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -410,7 +419,8 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
-attn_batches);
+attn_batches,
+flags);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -449,7 +459,8 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
-attn_batches);
+attn_batches,
+flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -471,7 +482,8 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
-attn_batches);
+attn_batches,
+flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...
@@ -113,6 +113,10 @@ std::vector<torch::Tensor> fwd_cuda(
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
@@ -185,7 +189,8 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Padded Softmax
bool softmax_success = false;
@@ -239,7 +244,8 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
-attn_batches);
+attn_batches,
+flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -371,6 +377,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -452,7 +462,8 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -474,7 +485,8 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
-attn_batches);
+attn_batches,
+flags);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -513,7 +525,8 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
-attn_batches);
+attn_batches,
+flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -535,7 +548,8 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
-attn_batches);
+attn_batches,
+flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...
@@ -88,6 +88,9 @@ std::vector<torch::Tensor> fwd_cuda(
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Input Linear Fwd
input_lin_results.copy_(input_biases);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -135,7 +138,8 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Padded Softmax
bool softmax_success = false;
@@ -180,7 +184,8 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
-attn_batches);
+attn_batches,
+flags);
outputs.copy_(output_biases);
@@ -270,6 +275,9 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -344,7 +352,8 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -366,7 +375,8 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
@@ -403,7 +413,8 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -425,7 +436,8 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...
@@ -80,6 +80,10 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Input Linear Fwd
input_lin_results.copy_(input_biases);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -127,7 +131,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Padded Softmax
bool softmax_success = false;
@@ -180,7 +185,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
-attn_batches);
+attn_batches,
+flags);
outputs.copy_(output_biases);
@@ -270,6 +276,9 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -344,7 +353,8 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -366,7 +376,8 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
@@ -398,7 +409,8 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -420,7 +432,8 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
...
@@ -79,6 +79,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
@@ -125,7 +128,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Padded Softmax
bool softmax_success = false;
@@ -178,7 +182,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
-attn_batches);
+attn_batches,
+flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -266,6 +271,9 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -339,7 +347,8 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -361,7 +370,8 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -400,7 +410,8 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -422,7 +433,8 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...
@@ -100,6 +100,11 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
@@ -147,7 +152,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Padded Softmax
bool softmax_success = false;
@@ -201,7 +207,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
-attn_batches);
+attn_batches,
+flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -317,6 +324,9 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
// Dropout Add Backward
apex_masked_scale_cuda<at::Half, float, uint32_t>(
@@ -397,7 +407,8 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
-attn_batches);
+attn_batches,
+flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -419,7 +430,8 @@ std::vector<torch::Tensor> bwd_cuda(
v_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -458,7 +470,8 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
@@ -480,7 +493,8 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr,
lead_dim,
batch_stride,
-attn_batches);
+attn_batches,
+flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...
@@ -42,7 +42,7 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo) {
+float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);
@@ -63,17 +63,17 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount) {
+float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
auto stream = c10::cuda::getCurrentCUDAStream();
if ( (transa == 't') && (transb == 'n') ) {
-if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
-else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
+if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); }
+else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); }
} else if ( (transa == 'n') && (transb == 'n') ) {
-if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
-else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
+if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); }
+else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); }
} else if ( (transa == 'n') && (transb == 't') ) {
-if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
-else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo); }
+if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); }
+else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); }
} else {
AT_ASSERTM(false, "TransA and TransB are invalid");
}
@@ -127,7 +127,7 @@ void HgemmStridedBatched(char transa, char transb, long m,
// gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
// b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
-b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount);
+b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, flags);
}