Unverified commit 063d720f authored by Hubert Lu, committed by GitHub

Add rocblas_alt_impl flag for backprop in MLP (#71)

* Add rocblas_alt_impl flag in MLP

* Refactor the rocblas_alt_impl implementation and use it only for backprop
parent b6a1f48b
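The change threads an extra flag argument through every mlp_gemm overload and its call sites: forward GEMMs pass 0, while the backward pass selects rocBLAS's alternate fp16 implementation (the denorm mitigation named in the title) when the library is new enough to provide it. Below is a minimal, hedged sketch of that selection logic, not code from the diff: select_gemm_flag is a hypothetical helper, and the commit itself inlines the same logic in mlp_bp, detecting the backward pass with at::BackwardPassGuard::is_backward_pass(). The sketch assumes a ROCm build whose rocBLAS headers define ROCBLAS_VERSION_MAJOR/ROCBLAS_VERSION_MINOR and rocblas_gemm_flags_fp16_alt_impl. The commit's diff, reproduced after the sketch, applies this flag to each GEMM overload and call site.

// Hedged sketch, not part of the diff: how the rocBLAS fp16 alt-impl flag is
// chosen so that only backward GEMMs opt into the denorm-mitigating kernels.
// select_gemm_flag is a hypothetical helper; the commit inlines this logic in
// mlp_bp and queries at::BackwardPassGuard::is_backward_pass() instead of
// taking a bool.
#ifdef __HIP_PLATFORM_HCC__
#include <rocblas.h>  // may be <rocblas/rocblas.h> on newer ROCm releases
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#endif

static int select_gemm_flag(bool is_backward_pass) {
  int flag = 0;  // forward prop always passes 0
#if defined(__HIP_PLATFORM_HCC__) && USE_GEMM_FLAGS_FP16_ALT_IMPL
  // Only backprop GEMMs request the alternate fp16 implementation, which
  // mitigates denormal issues in half-precision gradients.
  if (is_backward_pass) {
    flag = rocblas_gemm_flags_fp16_alt_impl;
  }
#endif
  return flag;
}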
+// New MLP with denorm mitigation only for backprop
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <assert.h>
@@ -20,6 +22,11 @@
 #define BIAS_RELU_BW_NTHREADS_Y 16 // backward number of thread in batch dim
 #define BIAS_RELU_RED_PER_THREAD 16 // backward minimal reduction length per thread
+// #ifdef __HIP_PLATFORM_HCC__
+// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+// #endif
 // move to a header later on
 #define ILP 4
 template<typename T>
@@ -70,7 +77,8 @@ cublasStatus_t mlp_gemm(
     int ldb,
     const float* beta,
     double* C,
-    int ldc) {
+    int ldc,
+    int flag) {
 #ifdef __HIP_PLATFORM_HCC__
   return rocblas_gemm_ex(
       handle,
@@ -96,7 +104,7 @@ cublasStatus_t mlp_gemm(
       rocblas_datatype_f64_r,
       rocblas_gemm_algo_standard,
       0,
-      0);
+      flag);
 #else
   return cublasGemmEx(
       handle,
@@ -136,7 +144,8 @@ cublasStatus_t mlp_gemm(
     int ldb,
     const float* beta,
     float* C,
-    int ldc) {
+    int ldc,
+    int flag) {
 #ifdef __HIP_PLATFORM_HCC__
   return rocblas_gemm_ex(
       handle,
@@ -162,7 +171,7 @@ cublasStatus_t mlp_gemm(
       rocblas_datatype_f32_r,
       rocblas_gemm_algo_standard,
       0,
-      0);
+      flag);
 #else
   return cublasGemmEx(
@@ -203,7 +212,8 @@ cublasStatus_t mlp_gemm(
     int ldb,
     float* beta,
     at::Half* C,
-    int ldc) {
+    int ldc,
+    int flag) {
 #ifdef __HIP_PLATFORM_HCC__
   return rocblas_gemm_ex(
       handle,
@@ -229,7 +239,7 @@ cublasStatus_t mlp_gemm(
       rocblas_datatype_f32_r,
       rocblas_gemm_algo_standard,
       0,
-      0);
+      flag);
 #else
   return cublasGemmEx(
       handle,
@@ -1402,7 +1412,8 @@ int mlp_fp(
         ifeat,
         &zero,
         output,
-        ofeat);
+        ofeat,
+        int(0)); // Do nothing for forward prop
     if (cublas_status != CUBLAS_STATUS_SUCCESS) {
       printf("GEMM fprop failed with %d\n", cublas_status);
@@ -1498,7 +1509,15 @@ int mlp_bp(
   // Get the stream from cublas handle to reuse for biasReLU kernel.
   cudaStream_t stream;
   cublasGetStream(handle, &stream);
+  int flag = 0;
+#ifdef __HIP_PLATFORM_HCC__
+#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+#endif
   int* y_offsets = (int*)malloc(num_layers * sizeof(int));
   get_y_offsets(batch_size, num_layers, output_features, y_offsets);
@@ -1617,7 +1636,8 @@ int mlp_bp(
         yfeat,
         &zero,
         dx,
-        xfeat);
+        xfeat,
+        flag); //
     if (cublas_status != CUBLAS_STATUS_SUCCESS) {
       printf("GEMM dgrad failed with %d\n", cublas_status);
@@ -1640,7 +1660,8 @@ int mlp_bp(
         yfeat,
         &zero,
         dweight,
-        xfeat);
+        xfeat,
+        flag); //
     if (cublas_status != CUBLAS_STATUS_SUCCESS) {
       printf("GEMM wgrad failed with %d\n", cublas_status);
@@ -1760,4 +1781,3 @@ template size_t get_mlp_bp_workspace_in_bytes<double>(
     int batch_size,
     int num_layers,
     const int* output_features);
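For reference, the trailing argument added to each mlp_gemm overload ends up as the final flags parameter of rocblas_gemm_ex on the ROCm path; the cuBLAS branch is unchanged in the hunks shown, so the flag is simply unused there. The sketch below reconstructs the call shape for the half-precision overload from the hunks above. It is a hedged illustration, not the apex source: gemm_fp16_with_flag is a hypothetical name and the matrix operands are passed as void pointers to keep the example self-contained.

// Hedged sketch only: mirrors the fp16 rocblas_gemm_ex call shown in the
// hunks above, with the new flag forwarded as the last argument.
#ifdef __HIP_PLATFORM_HCC__
#include <rocblas.h>  // may be <rocblas/rocblas.h> on newer ROCm releases

rocblas_status gemm_fp16_with_flag(
    rocblas_handle handle,
    rocblas_operation transa,
    rocblas_operation transb,
    int m, int n, int k,
    const float* alpha,
    const void* A, int lda,   // half-precision input matrices
    const void* B, int ldb,
    const float* beta,
    void* C, int ldc,         // half-precision result, read and written in place
    int flag) {               // 0 for fprop, rocblas_gemm_flags_fp16_alt_impl for bprop
  return rocblas_gemm_ex(
      handle, transa, transb, m, n, k,
      alpha,
      A, rocblas_datatype_f16_r, lda,
      B, rocblas_datatype_f16_r, ldb,
      beta,
      C, rocblas_datatype_f16_r, ldc,  // C as the input matrix
      C, rocblas_datatype_f16_r, ldc,  // C again as the output matrix D
      rocblas_datatype_f32_r,          // accumulate in fp32
      rocblas_gemm_algo_standard,
      0,                               // solution index
      flag);                           // fp16 alt-impl flag or 0
}
#endif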