"vscode:/vscode.git/clone" did not exist on "648c970ecc19df3a51e52e7c7eb53d980d849da1"
Unverified commit cf77e9b5 authored by Hubert Lu, committed by GitHub

Make rocblas_gemm_flags_fp16_alt_impl backward-compat for new naming (#79)

* Make rocblas_gemm_flags_fp16_alt_impl backward-compat for new naming

* Use BACKWARD_PASS_GUARD_CLASS to prevent lengthy if-statement
parent 27a47345
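
Note (not part of the commit): the sketch below illustrates the mechanism this patch relies on. setup.py inspects ATen's Context.h and, when it finds either guard class, adds -DBACKWARD_PASS_GUARD together with -DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard or -DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard to the hipcc arguments; the kernel sources then name the class only through the macro, so the same line builds against either ATen naming. The at:: stubs, the flag constant, and the g++ build lines here are placeholders standing in for the real ATen / rocBLAS headers and the real hipcc invocation.

// Minimal, self-contained sketch of the compile-time guard selection.
// The at:: structs and the flag constant are stand-ins for the real
// ATen (Context.h) and rocBLAS (rocblas.h) definitions, not the actual APIs.
//
// Build roughly the way setup.py would for a newer ROCm PyTorch:
//   g++ -DBACKWARD_PASS_GUARD -DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard guard_sketch.cpp
// or for an older PyTorch that still ships BackwardPassGuard:
//   g++ -DBACKWARD_PASS_GUARD -DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard guard_sketch.cpp
#include <cstdint>
#include <iostream>

namespace at {
// Placeholder guard classes; the real ones are exposed by ATen's Context.h.
struct BackwardPassGuard     { static bool is_backward_pass() { return true; } };
struct ROCmBackwardPassGuard { static bool is_backward_pass() { return true; } };
}  // namespace at

// Placeholder value standing in for rocblas_gemm_flags_fp16_alt_impl from rocblas.h.
constexpr uint32_t rocblas_gemm_flags_fp16_alt_impl = 0x4;

int main() {
  uint32_t flags = 0;
#ifdef BACKWARD_PASS_GUARD
  // BACKWARD_PASS_GUARD_CLASS expands to whichever class name setup.py detected,
  // so this single line compiles against either ATen naming.
  flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass()
              ? rocblas_gemm_flags_fp16_alt_impl
              : 0;
#endif
  std::cout << "gemm flags = " << flags << "\n";
  return 0;
}

Routing the class name through a single macro is what lets the kernels avoid a per-version if/else chain, which is the "lengthy if-statement" the second commit bullet refers to.
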
@@ -325,8 +325,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
......
@@ -381,12 +381,12 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
 // Dropout Add Backward
 apex_masked_scale_cuda<at::Half,float,uint32_t>(
   static_cast<at::Half const*>(output_grads.data_ptr()),
......
@@ -280,8 +280,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
......
@@ -280,8 +280,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
......
@@ -276,8 +276,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
......
@@ -327,8 +327,8 @@ std::vector<torch::Tensor> bwd_cuda(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
......
@@ -1510,8 +1510,8 @@ int mlp_bp(
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
 #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
 #if USE_GEMM_FLAGS_FP16_ALT_IMPL
-#ifdef ROCM_BACKWARD_PASS_GUARD
-flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#ifdef BACKWARD_PASS_GUARD
+flag = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
 #endif
 #endif
 #endif
......
@@ -20,9 +20,15 @@ context_file = os.path.join(torch_dir, "include", "ATen", "Context.h")
 if os.path.exists(context_file):
     lines = open(context_file, 'r').readlines()
     found_Backward_Pass_Guard = False
+    found_ROCmBackward_Pass_Guard = False
     for line in lines:
         if "BackwardPassGuard" in line:
-            found_Backward_Pass_Guard = True
+            # BackwardPassGuard has been renamed to ROCmBackwardPassGuard
+            # https://github.com/pytorch/pytorch/pull/71881/commits/4b82f5a67a35406ffb5691c69e6b4c9086316a43
+            if "ROCmBackwardPassGuard" in line:
+                found_ROCmBackward_Pass_Guard = True
+            else:
+                found_Backward_Pass_Guard = True
             break
 def get_cuda_bare_metal_version(cuda_dir):
@@ -245,6 +251,12 @@ if "--cuda_ext" in sys.argv:
                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                          'nvcc': nvcc_args_layer_norm if not IS_ROCM_PYTORCH else hipcc_args_layer_norm}))
+    hipcc_args_mlp = ['-O3'] + version_dependent_macros
+    if found_Backward_Pass_Guard:
+        hipcc_args_mlp = hipcc_args_mlp + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard']
+    if found_ROCmBackward_Pass_Guard:
+        hipcc_args_mlp = hipcc_args_mlp + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard']
     print ("INFO: Building the MLP Extension.")
     ext_modules.append(
         CUDAExtension(name='mlp_cuda',
......@@ -252,8 +264,8 @@ if "--cuda_ext" in sys.argv:
'csrc/mlp_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros if not found_Backward_Pass_Guard
else ['-O3'] + version_dependent_macros + ['-DROCM_BACKWARD_PASS_GUARD']}))
'nvcc':['-O3'] + version_dependent_macros
if not IS_ROCM_PYTORCH else hipcc_args_mlp}))
ext_modules.append(
CUDAExtension(name='fused_dense_cuda',
@@ -493,7 +505,9 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
                      '-U__HIP_NO_HALF_OPERATORS__',
                      '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
     if found_Backward_Pass_Guard:
-        hipcc_args_mha = hipcc_args_mha + ['-DROCM_BACKWARD_PASS_GUARD']
+        hipcc_args_mha = hipcc_args_mha + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard']
+    if found_ROCmBackward_Pass_Guard:
+        hipcc_args_mha = hipcc_args_mha + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard']
     ext_modules.append(
         CUDAExtension(
......