Unverified commit 5ecad142 authored by Hubert Lu, committed by GitHub

Make rocblas_gemm_flags_fp16_alt_impl in MHA and MLP backward compatible with old PyTorch versions (#74)

* First attempt to make rocblas flag backward compatible

* Fix some bugs

* Fix some bugs

* Make rocblas_gemm_flags_fp16_alt_impl in MHA backward compatible with old PyTorch versions

* Add groupbn extension unit tests for ROCm

* Fix some bugs
parent 063d720f
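
The same guard block lands in every ROCm backward path in the hunks below; it is consolidated here for readability (this is the pattern from the diff itself, not a separate implementation). The alternate fp16 implementation is requested only when (a) the build targets ROCm, (b) rocBLAS is at least 2.42 and therefore defines rocblas_gemm_flags_fp16_alt_impl, and (c) the PyTorch headers being compiled against expose at::BackwardPassGuard, which setup.py signals by passing -DROCM_BACKWARD_PASS_GUARD. In every other case the flags variable keeps its default value, so the extensions still build against older PyTorch versions.

#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef ROCM_BACKWARD_PASS_GUARD
  // Use the fp16 alternate implementation only while autograd is running a backward pass.
  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif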
@@ -87,9 +87,6 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   char b_layout_n{'n'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
   // Input Linear Q Fwd
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -321,9 +318,16 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
+#ifdef __HIP_PLATFORM_HCC__
+#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+#ifdef ROCM_BACKWARD_PASS_GUARD
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+#endif
+#endif
   // Output Linear Dgrad
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                        CUBLAS_OP_N,
...
@@ -113,10 +113,6 @@ std::vector<torch::Tensor> fwd_cuda(
                        1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
                        static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
   // Input Linear Q Fwd
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                        CUBLAS_OP_T,
@@ -377,10 +373,15 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
+#ifdef __HIP_PLATFORM_HCC__
+#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+#ifdef ROCM_BACKWARD_PASS_GUARD
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+#endif
+#endif
   // Dropout Add Backward
   apex_masked_scale_cuda<at::Half,float,uint32_t>(
...
@@ -88,9 +88,7 @@ std::vector<torch::Tensor> fwd_cuda(
   char b_layout_n{'n'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -275,8 +273,14 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
+#ifdef __HIP_PLATFORM_HCC__
+#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+#ifdef ROCM_BACKWARD_PASS_GUARD
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+#endif
+#endif
   // Output Linear Dgrad
...
@@ -80,9 +80,6 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
   char b_layout_n{'n'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
@@ -276,8 +273,14 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
+#ifdef __HIP_PLATFORM_HCC__
+#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+#ifdef ROCM_BACKWARD_PASS_GUARD
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+#endif
+#endif
   // Output Linear Dgrad
...
@@ -79,9 +79,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   char b_layout_n{'n'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
   // Input Linear Fwd
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                        CUBLAS_OP_T,
@@ -271,8 +269,14 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
+#ifdef __HIP_PLATFORM_HCC__
+#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+#ifdef ROCM_BACKWARD_PASS_GUARD
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+#endif
+#endif
   // Output Linear Dgrad
...
@@ -100,11 +100,6 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                        1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
                        static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
   // Input Linear Fwd
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                        CUBLAS_OP_T,
@@ -324,8 +319,14 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#if USE_GEMM_FLAGS_FP16_ALT_IMPL
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
-#endif
+#ifdef __HIP_PLATFORM_HCC__
+#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+#ifdef ROCM_BACKWARD_PASS_GUARD
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+#endif
+#endif
   // Dropout Add Backward
...
@@ -4,7 +4,6 @@ import sys
 test_dirs = ["groupbn", "layer_norm", "multihead_attn", "."]  # "." for test_label_smoothing.py
 ROCM_BLACKLIST = [
-    "groupbn",
     "layer_norm"
 ]
...
@@ -22,10 +22,6 @@
 #define BIAS_RELU_BW_NTHREADS_Y 16 // backward number of thread in batch dim
 #define BIAS_RELU_RED_PER_THREAD 16 // backward minimal reduction length per thread
-// #ifdef __HIP_PLATFORM_HCC__
-// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
-// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
-// #endif
 // move to a header later on
 #define ILP 4
@@ -1514,7 +1510,9 @@ int mlp_bp(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
+#ifdef ROCM_BACKWARD_PASS_GUARD
   flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
 #endif
 #endif
...
@@ -9,6 +9,21 @@ import os
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
+torch_dir = torch.__path__[0]
+
+# https://github.com/pytorch/pytorch/pull/71881
+# For the extensions that use rocblas_gemm_flags_fp16_alt_impl, check whether at::BackwardPassGuard
+# exists so that they remain backward compatible with older PyTorch versions. This check, and the
+# ROCM_BACKWARD_PASS_GUARD define passed via the nvcc/hipcc args, can be retired once the PR above
+# is merged into PyTorch upstream.
+context_file = os.path.join(torch_dir, "include", "ATen", "Context.h")
+found_Backward_Pass_Guard = False
+if os.path.exists(context_file):
+    lines = open(context_file, 'r').readlines()
+    for line in lines:
+        if "BackwardPassGuard" in line:
+            found_Backward_Pass_Guard = True
+            break
+
 def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -237,7 +252,9 @@ if "--cuda_ext" in sys.argv:
                               'csrc/mlp_cuda.cu'],
                       include_dirs=[os.path.join(this_dir, 'csrc')],
                       extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                          'nvcc': ['-O3'] + version_dependent_macros}))
+                                          'nvcc': ['-O3'] + version_dependent_macros if not found_Backward_Pass_Guard
+                                                  else ['-O3'] + version_dependent_macros + ['-DROCM_BACKWARD_PASS_GUARD']}))
     ext_modules.append(
         CUDAExtension(name='fused_dense_cuda',
                       sources=['csrc/fused_dense.cpp',
@@ -365,7 +382,6 @@ if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
 # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
 # See https://github.com/pytorch/pytorch/pull/70650
 generator_flag = []
-torch_dir = torch.__path__[0]
 if os.path.exists(os.path.join(torch_dir, "include", "ATen", "cuda", "CUDAGeneratorImpl.h")):
     generator_flag = ["-DNEW_GENERATOR_PATH"]
@@ -475,6 +491,8 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
                      '-I/opt/rocm/include/rocrand',
                      '-U__HIP_NO_HALF_OPERATORS__',
                      '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
+    if found_Backward_Pass_Guard:
+        hipcc_args_mha = hipcc_args_mha + ['-DROCM_BACKWARD_PASS_GUARD']
     ext_modules.append(
         CUDAExtension(name='fast_additive_mask_softmax_dropout',
...
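
For context on the runtime check itself: the at::BackwardPassGuard that setup.py probes Context.h for (see the PyTorch PR referenced above) is, conceptually, a RAII guard around a thread-local flag that autograd holds for the duration of a backward pass. Below is a minimal illustrative sketch of that idea, assuming the thread-local-flag design; it is not the actual ATen source.

#include <iostream>

namespace at {

// Illustrative sketch only: a RAII guard that marks "we are inside a backward
// pass" on the current thread, so kernels can query it via is_backward_pass().
struct BackwardPassGuard {
  BackwardPassGuard()  { is_backward_pass_ = true; }
  ~BackwardPassGuard() { is_backward_pass_ = false; }
  static bool is_backward_pass() { return is_backward_pass_; }

 private:
  static thread_local bool is_backward_pass_;
};

thread_local bool BackwardPassGuard::is_backward_pass_ = false;

}  // namespace at

int main() {
  std::cout << at::BackwardPassGuard::is_backward_pass() << "\n";    // 0
  {
    at::BackwardPassGuard guard;  // autograd would hold one of these during backward
    std::cout << at::BackwardPassGuard::is_backward_pass() << "\n";  // 1
  }
  std::cout << at::BackwardPassGuard::is_backward_pass() << "\n";    // 0
}

With a shape like this, is_backward_pass() is cheap enough to consult right before each rocblas_gemm_ex call, which is what the extension code above does.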