Unverified Commit 5ecad142 authored by Hubert Lu, committed by GitHub

Make rocblas_gemm_flags_fp16_alt_impl in MHA and MLP backward compatible with old PyTorch versions (#74)

* First attempt to make rocblas flag backward compatible

* Fix some bugs

* Fix some bugs

* Make rocblas_gemm_flags_fp16_alt_impl in MHA backward compatible with old PyTorch versions

* Add groupbn extension unit tests for ROCm

* Fix some bugs
parent 063d720f
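
For orientation before the hunks: every kernel change below applies the same guard pattern, summarized here as a sketch. This block is commentary, not part of the diff; it assumes `flags` is the variable each backward kernel already passes to rocblas_gemm_ex, and all other identifiers are taken verbatim from the hunks that follow.

#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)  // rocBLAS >= 2.42 only
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef ROCM_BACKWARD_PASS_GUARD  // defined by setup.py only when ATen/Context.h declares BackwardPassGuard
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif

With an older PyTorch that lacks at::BackwardPassGuard, setup.py never defines ROCM_BACKWARD_PASS_GUARD, so `flags` keeps its default value and the extensions still compile. The corresponding flag selection in the forward-pass kernels is simply removed, as the first group of hunks shows.
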
@@ -87,9 +87,6 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -321,9 +318,16 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef ROCM_BACKWARD_PASS_GUARD
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
@@ -113,10 +113,6 @@ std::vector<torch::Tensor> fwd_cuda(
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
@@ -377,10 +373,15 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef ROCM_BACKWARD_PASS_GUARD
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -88,9 +88,7 @@ std::vector<torch::Tensor> fwd_cuda(
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
// Input Linear Fwd
input_lin_results.copy_(input_biases);
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -275,9 +273,15 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef ROCM_BACKWARD_PASS_GUARD
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -80,9 +80,6 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
// Input Linear Fwd
input_lin_results.copy_(input_biases);
@@ -276,9 +273,15 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef ROCM_BACKWARD_PASS_GUARD
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -79,9 +79,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
@@ -271,9 +269,15 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef ROCM_BACKWARD_PASS_GUARD
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -100,11 +100,6 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
@@ -324,9 +319,15 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef ROCM_BACKWARD_PASS_GUARD
flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
// Dropout Add Backward
apex_masked_scale_cuda<at::Half, float, uint32_t>(
@@ -4,7 +4,6 @@ import sys
test_dirs = ["groupbn", "layer_norm", "multihead_attn", "."] # "." for test_label_smoothing.py
ROCM_BLACKLIST = [
"groupbn",
"layer_norm"
]
@@ -22,10 +22,6 @@
#define BIAS_RELU_BW_NTHREADS_Y 16 // backward: number of threads in the batch dim
#define BIAS_RELU_RED_PER_THREAD 16 // backward: minimal reduction length per thread
// #ifdef __HIP_PLATFORM_HCC__
// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// #endif
// move to a header later on
#define ILP 4
@@ -1514,9 +1510,11 @@ int mlp_bp(
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
#ifdef ROCM_BACKWARD_PASS_GUARD
flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
#endif
int* y_offsets = (int*)malloc(num_layers * sizeof(int));
get_y_offsets(batch_size, num_layers, output_features, y_offsets);
@@ -9,6 +9,21 @@ import os
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
torch_dir = torch.__path__[0]
# https://github.com/pytorch/pytorch/pull/71881
# For the extensions that use rocblas_gemm_flags_fp16_alt_impl, we need to check whether at::BackwardPassGuard exists,
# which keeps those extensions backward compatible with older PyTorch versions.
# The check and the ROCM_BACKWARD_PASS_GUARD define passed via the nvcc/hipcc args can be retired once the PR above is merged into PyTorch upstream.
context_file = os.path.join(torch_dir, "include", "ATen", "Context.h")
# Default to False so the flag is always defined, even when Context.h is missing or lacks the guard.
found_Backward_Pass_Guard = False
if os.path.exists(context_file):
    with open(context_file, 'r') as f:
        for line in f:
            if "BackwardPassGuard" in line:
                found_Backward_Pass_Guard = True
                break
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -237,7 +252,9 @@ if "--cuda_ext" in sys.argv:
'csrc/mlp_cuda.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
'nvcc':['-O3'] + version_dependent_macros if not found_Backward_Pass_Guard
else ['-O3'] + version_dependent_macros + ['-DROCM_BACKWARD_PASS_GUARD']}))
ext_modules.append(
CUDAExtension(name='fused_dense_cuda',
sources=['csrc/fused_dense.cpp',
@@ -365,7 +382,6 @@ if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
# Check if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "cuda", "CUDAGeneratorImpl.h")):
generator_flag = ["-DNEW_GENERATOR_PATH"]
@@ -475,6 +491,8 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
'-I/opt/rocm/include/rocrand',
'-U__HIP_NO_HALF_OPERATORS__',
'-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
if found_Backward_Pass_Guard:
hipcc_args_mha = hipcc_args_mha + ['-DROCM_BACKWARD_PASS_GUARD']
ext_modules.append(
CUDAExtension(name='fast_additive_mask_softmax_dropout',