Unverified Commit 063d720f authored by Hubert Lu, committed by GitHub

Add rocblas_alt_impl flag for backprop in MLP (#71)

* Add rocblas_alt_impl flag in MLP

* Refactor rocblas_alt_impl implementation and only use it for backprop
parent b6a1f48b
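For context on how the new argument is wired through, here is a minimal sketch written against the diff below: the backprop-only flag is computed once in mlp_bp (gated on the rocBLAS version and the ROCm fork's at::BackwardPassGuard, both visible in the diff) and then handed to each backward mlp_gemm call as its new trailing parameter, while the forward path keeps passing 0. The helper name select_backprop_gemm_flag and the rocblas header include are illustrative assumptions, not part of the commit.

// Sketch only: mirrors the flag selection this commit adds to mlp_bp.
// Assumes a ROCm build; the rocblas header path may differ by ROCm release.
#include <ATen/ATen.h>
#ifdef __HIP_PLATFORM_HCC__
#include <rocblas.h>   // provides rocblas_gemm_flags_fp16_alt_impl and version macros
#endif

static int select_backprop_gemm_flag() {
  int flag = 0;  // 0 = standard GEMM path (what forward prop always passes)
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#if PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242
  // Enable the fp16 alt implementation (denorm mitigation) only while the
  // autograd backward pass is running, as the diff does inside mlp_bp.
  flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
  return flag;  // passed as the last argument of each backward mlp_gemm(...) call
}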
// New MLP with denorm mitigation only for backprop
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
@@ -20,6 +22,11 @@
#define BIAS_RELU_BW_NTHREADS_Y 16 // backward number of thread in batch dim
#define BIAS_RELU_RED_PER_THREAD 16 // backward minimal reduction length per thread
// #ifdef __HIP_PLATFORM_HCC__
// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// #endif
// move to a header later on
#define ILP 4
template<typename T>
@@ -70,7 +77,8 @@ cublasStatus_t mlp_gemm(
int ldb,
const float* beta,
double* C,
int ldc) {
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
@@ -96,7 +104,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
flag);
#else
return cublasGemmEx(
handle,
@@ -136,7 +144,8 @@ cublasStatus_t mlp_gemm(
int ldb,
const float* beta,
float* C,
int ldc) {
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
@@ -162,7 +171,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
flag);
#else
return cublasGemmEx(
@@ -203,7 +212,8 @@ cublasStatus_t mlp_gemm(
int ldb,
float* beta,
at::Half* C,
int ldc) {
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
@@ -229,7 +239,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
flag);
#else
return cublasGemmEx(
handle,
@@ -1402,7 +1412,8 @@ int mlp_fp(
ifeat,
&zero,
output,
ofeat);
ofeat,
int(0)); // Forward prop: no alt-impl flag needed
if (cublas_status != CUBLAS_STATUS_SUCCESS) {
printf("GEMM fprop failed with %d\n", cublas_status);
@@ -1498,6 +1509,14 @@ int mlp_bp(
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
// Select the rocblas fp16 alt-impl flag (denorm mitigation) for backprop GEMMs only.
int flag = 0;
#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
#endif
int* y_offsets = (int*)malloc(num_layers * sizeof(int));
get_y_offsets(batch_size, num_layers, output_features, y_offsets);
@@ -1617,7 +1636,8 @@ int mlp_bp(
yfeat,
&zero,
dx,
xfeat);
xfeat,
flag); // dgrad GEMM: pass the backprop alt-impl flag
if (cublas_status != CUBLAS_STATUS_SUCCESS) {
printf("GEMM dgrad failed with %d\n", cublas_status);
@@ -1640,7 +1660,8 @@ int mlp_bp(
yfeat,
&zero,
dweight,
xfeat);
xfeat,
flag); // wgrad GEMM: pass the backprop alt-impl flag
if (cublas_status != CUBLAS_STATUS_SUCCESS) {
printf("GEMM wgrad failed with %d\n", cublas_status);
@@ -1760,4 +1781,3 @@ template size_t get_mlp_bp_workspace_in_bytes<double>(
int batch_size,
int num_layers,
const int* output_features);