Unverified commit cf77e9b5, authored by Hubert Lu, committed by GitHub

Make rocblas_gemm_flags_fp16_alt_impl backward-compat for new naming (#79)

* Make rocblas_gemm_flags_fp16_alt_impl backward-compatible with the new naming

* Use BACKWARD_PASS_GUARD_CLASS to avoid a lengthy if-statement
parent 27a47345
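
The idea behind the change: instead of hard-coding at::BackwardPassGuard (which PyTorch renamed to at::ROCmBackwardPassGuard in pytorch/pytorch#71881), the call sites reference a BACKWARD_PASS_GUARD_CLASS macro whose value setup.py injects at build time after checking which class name exists in ATen/Context.h. Below is a minimal standalone sketch of the pattern; the guard structs and the flag constant are stand-ins for illustration, not ATen's or rocBLAS's real definitions.

// Minimal sketch of the macro-based compatibility shim (illustrative only).
// The guard structs below are stand-ins for ATen's BackwardPassGuard /
// ROCmBackwardPassGuard, and the flag constant is a placeholder, not the
// real rocblas enum value.
#include <iostream>

namespace at {
struct BackwardPassGuard {          // stand-in for the pre-rename class
  static bool is_backward_pass() { return true; }
};
struct ROCmBackwardPassGuard {      // stand-in for the post-rename class
  static bool is_backward_pass() { return true; }
};
}  // namespace at

// setup.py detects which class name appears in ATen/Context.h and passes e.g.
//   -DBACKWARD_PASS_GUARD -DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard
// For this self-contained sketch we pick a default if nothing was passed.
#ifndef BACKWARD_PASS_GUARD_CLASS
#define BACKWARD_PASS_GUARD
#define BACKWARD_PASS_GUARD_CLASS ROCmBackwardPassGuard
#endif

int main() {
  constexpr int rocblas_gemm_flags_fp16_alt_impl = 0x4;  // placeholder value
  int flags = 0;
#ifdef BACKWARD_PASS_GUARD
  // One call site compiles against either class name; only the -D flags differ.
  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass()
              ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  std::cout << "gemm flags: " << flags << "\n";
  return 0;
}

Compiling the same file with -DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard instead exercises the old name, which is what the found_Backward_Pass_Guard branch in setup.py selects.
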
@@ -325,8 +325,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
...
@@ -381,12 +381,12 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
 // Dropout Add Backward
 apex_masked_scale_cuda<at::Half,float,uint32_t>(
   static_cast<at::Half const*>(output_grads.data_ptr()),
...
@@ -280,8 +280,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
...
@@ -280,8 +280,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
...
@@ -276,8 +276,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
...
@@ -327,8 +327,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
...
@@ -1510,8 +1510,8 @@ int mlp_bp(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-  flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+  flag = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
...
@@ -20,9 +20,15 @@ context_file = os.path.join(torch_dir, "include", "ATen", "Context.h")
 if os.path.exists(context_file):
     lines = open(context_file, 'r').readlines()
     found_Backward_Pass_Guard = False
+    found_ROCmBackward_Pass_Guard = False
     for line in lines:
         if "BackwardPassGuard" in line:
-            found_Backward_Pass_Guard = True
+            # BackwardPassGuard has been renamed to ROCmBackwardPassGuard
+            # https://github.com/pytorch/pytorch/pull/71881/commits/4b82f5a67a35406ffb5691c69e6b4c9086316a43
+            if "ROCmBackwardPassGuard" in line:
+                found_ROCmBackward_Pass_Guard = True
+            else:
+                found_Backward_Pass_Guard = True
             break
 
 def get_cuda_bare_metal_version(cuda_dir):
@@ -245,6 +251,12 @@ if "--cuda_ext" in sys.argv:
                           extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                               'nvcc': nvcc_args_layer_norm if not IS_ROCM_PYTORCH else hipcc_args_layer_norm}))
 
+    hipcc_args_mlp = ['-O3'] + version_dependent_macros
+    if found_Backward_Pass_Guard:
+        hipcc_args_mlp = hipcc_args_mlp + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard']
+    if found_ROCmBackward_Pass_Guard:
+        hipcc_args_mlp = hipcc_args_mlp + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard']
+
     print ("INFO: Building the MLP Extension.")
     ext_modules.append(
         CUDAExtension(name='mlp_cuda',
@@ -252,8 +264,8 @@ if "--cuda_ext" in sys.argv:
                        'csrc/mlp_cuda.cu'],
                       include_dirs=[os.path.join(this_dir, 'csrc')],
                       extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                          'nvcc':['-O3'] + version_dependent_macros if not found_Backward_Pass_Guard
-                                                 else ['-O3'] + version_dependent_macros + ['-DROCM_BACKWARD_PASS_GUARD']}))
+                                          'nvcc':['-O3'] + version_dependent_macros
+                                                 if not IS_ROCM_PYTORCH else hipcc_args_mlp}))
 
     ext_modules.append(
         CUDAExtension(name='fused_dense_cuda',
@@ -493,7 +505,9 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
                      '-U__HIP_NO_HALF_OPERATORS__',
                      '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
     if found_Backward_Pass_Guard:
-        hipcc_args_mha = hipcc_args_mha + ['-DROCM_BACKWARD_PASS_GUARD']
+        hipcc_args_mha = hipcc_args_mha + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard']
+    if found_ROCmBackward_Pass_Guard:
+        hipcc_args_mha = hipcc_args_mha + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard']
 
     ext_modules.append(
         CUDAExtension(
...