Unverified Commit e57c84e0 authored by sarunyap, committed by GitHub

Enable group batch norm (--bnp) on ROCm (only bn_group = 1) (#51)

* Enable group batch norm (--bnp) on ROCm (only bn_group = 1)

Enable NHWC group batch norm on a single GPU on ROCm (bn_group = 1).
The multi-GPU case (bn_group > 1) will be revisited in the future.

The following are the main changes:

1) Use MIOpen data structures/functions in HIP instead of cuDNN
2) For the warp-level primitive code, ensure that it operates on a
   64-thread-wide wavefront instead of a 32-thread-wide warp
3) Disable all the bn_group > 1 paths

Notes:

1) Multi-stream is not tested.
2) We have not optimized for performance.

* Fix bnp hipification

Avoid calling hipify-perl in setup.py and rely on PyTorch's internal
hipification mechanism.

* Make bnp data pointers contiguous

The contrib group batch norm implementation assumes that all input
tensors are contiguous.  When non-contiguous tensors are passed in, the
function produces incorrect results.  This commit explicitly calls
.contiguous() on all input tensors before accessing their data
pointers.
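
A minimal sketch of the pattern (illustrative helper name, not the actual
binding signature): materialize a contiguous tensor first, then take the raw
data pointer, since the kernels index the buffer assuming a dense NHWC layout.

```cpp
#include <ATen/ATen.h>

void set_pointers_sketch(const at::Tensor& x, const at::Tensor& scale) {
  // contiguous() returns the tensor itself when it is already dense, so
  // callers that pass contiguous tensors pay no extra copy.
  at::Tensor x_c = x.contiguous();
  at::Tensor scale_c = scale.contiguous();

  at::Half* x_ptr = x_c.data_ptr<at::Half>();
  float* scale_ptr = scale_c.data_ptr<float>();
  // ... hand x_ptr / scale_ptr to the NhwcBatchNorm wrapper ...
  (void)x_ptr;
  (void)scale_ptr;
}
```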

* Fix HIP lane id in bnp

Fix a typo in the lane id computation.

* Fix ReLU bitmask for HIP in bnp

The ReLU bitmask is derived using the __ballot function, which returns
a 64-bit value in HIP.  This commit fixes the ReLU bitmask storage size
and offsets on ROCm.

This patch also fixes the kernel to set the ReLU bitmask bit when the
data is less than or equal to zero (not only strictly less than).  Not
doing so can cause a stability issue.
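
A hedged sketch of the difference, assuming a simple per-lane predicate (this
is not the kernel's actual bitmask layout): on ROCm, __ballot produces a
64-bit mask covering the 64-lane wavefront, while CUDA's __ballot_sync
produces a 32-bit mask, and the bit is set when the value is less than or
equal to zero.

```cpp
// Sketch only; the real kernel packs these masks into the relu_bitmask
// workspace with per-group offsets.
#ifdef __HIP_PLATFORM_HCC__
using ballot_t = uint64_t;                        // one bit per lane, 64 lanes
__device__ ballot_t relu_ballot(float v) {
  return __ballot(v <= 0.f);                      // 64-bit result on HIP
}
#else
using ballot_t = unsigned int;                    // one bit per lane, 32 lanes
__device__ ballot_t relu_ballot(float v) {
  return __ballot_sync(0xffffffffu, v <= 0.f);    // 32-bit result on CUDA
}
#endif
```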

* Remove multiple of 64 offset for HIP in bnp

The multiple of 64 offset is not necessary.

* Use FP16 intermediate output to determine whether to rectify in bnp

Group batch norm takes FP16 tensors and produces FP16 output; however,
all arithmetic is done in FP32, so the intermediate outputs are in FP32.
In the fusion kernels, ReLU inspects the FP32 intermediate output to
decide whether to rectify it.  ReLU must rectify the intermediate output
if it is less than or "equal" to zero.  The intermediate FP32 output can
be very close to zero and become exactly zero when converted to FP16.
In that case, the output is not rectified when it should be.  Since the
output is not rectified in the forward pass, the gradient is not
rectified in the backward pass.  This can cause a stability issue.

This patch can have a negative impact on group batch norm performance,
as the FP32-FP16 conversion is performed multiple times.
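
A hedged sketch of the fix (illustrative names; the real fusion kernel works
on vectorized registers): the rectify decision is made on the value after the
FP32-to-FP16 round trip, so an FP32 result that flushes to zero in FP16 is
treated as less than or equal to zero, keeping the forward output and the
backward gradient consistent.

```cpp
#include <cuda_fp16.h>

// Sketch only: decide rectification on the FP16 value the user will see,
// not on the FP32 intermediate result.
__device__ __half fused_relu_sketch(float bn_out_f32) {
  __half out_f16 = __float2half(bn_out_f32);   // FP32 -> FP16 (the stored output)
  if (__half2float(out_f16) <= 0.f) {          // compare the converted value
    out_f16 = __float2half(0.f);               // rectify
  }
  return out_f16;
}
```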

* Disable dispatchX ParallelSums in HIP in bnp

dispatchX is not required for the bn_group = 1 case.

* Use traditional load/store for HIP in bnp

The built-in function has a high floating point rounding error, so we
replace it with traditional load/store.  Doing so breaks the aligned
pointer property in the load/store functions, so we conservatively use
traditional load/store for all memory accesses.
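
A hedged illustration of what "traditional" load/store means here (the
kernel's own StorageType packing is more involved): each FP16 element is
loaded and converted individually instead of through a packed,
alignment-dependent vector load, which drops the alignment requirement at
some bandwidth cost.

```cpp
#include <cuda_fp16.h>

// Sketch only: per-element loads make no alignment assumption, unlike a
// packed load such as *reinterpret_cast<const int2*>(ptr).
__device__ void load4_half_to_float(const __half* ptr, float out[4]) {
  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    out[i] = __half2float(ptr[i]);   // scalar load + convert
  }
}
```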

* Replace shfl_down with shfl_sync in parallel sums for HIP in bnp

This commit separates the HIP code from the CUDA code in the parallel sums.
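
A hedged sketch of the separation (a generic warp sum, not the kernel's
actual ParallelSums code): the CUDA path keeps the *_sync shuffle
intrinsics, while the HIP path uses the plain intrinsics over a 64-lane
wavefront.

```cpp
__device__ float warp_reduce_sum(float val) {
#ifdef __HIP_PLATFORM_HCC__
  // 64-lane wavefront; HIP has no *_sync shuffle variants.
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down(val, offset);
#else
  // 32-lane warp; CUDA uses the *_sync variants.
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffffu, val, offset);
#endif
  return val;
}
```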

* Remove -U__HIP_NO_HALF_CONVERSIONS__ for HIP in bnp

Since the built-in function is removed, -U__HIP_NO_HALF_CONVERSIONS__ is
no longer needed.

* Preserve CUDA's ReLU condition path for USE_ADD_RELU in bnp

* Add test for bnp

The test evaluates correctness of batch norm, batch norm + ReLU, and
batch norm + add + ReLU against the reference implementation.

For the forward activation output, we validate against PyTorch's
implementation.  The group batch norm activation output must be allclose
to the PyTorch activation output for the test to pass.

For the backward gradient output, we validate against a Python reference
implementation.  Due to floating point rounding error in the batch norm
implementation, the group batch norm gradient output might not be
allclose to the Python reference output when ReLU is used, even though
the majority of the elements are very close to each other.  Thus,
instead of allclose, we use a norm-difference threshold to determine
whether the test passes.

* Use the warp size variable rather than hard-coding the warp size in bnp

Use C10_WARP_SIZE from c10/macros/Macros.h in the host functions and
warpSize in the device kernels instead of hard-coding the warp size.
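
A hedged sketch of the convention (illustrative helpers, not the actual
kernel code): host-side sizing uses C10_WARP_SIZE, which is 32 on CUDA
builds and 64 on ROCm builds, while device code reads the warpSize builtin,
so neither side hard-codes 32.

```cpp
#include <c10/macros/Macros.h>

// Host side: shared memory needed for the per-warp reduction buffers.
inline int reduction_bytes(int threads_per_cta, int threads_per_pixel,
                           int elements_per_ldg) {
  return threads_per_pixel * (threads_per_cta / C10_WARP_SIZE) *
         elements_per_ldg * static_cast<int>(sizeof(float));
}

// Device side: lane index within the current warp/wavefront.
__device__ inline int lane_id() {
  return threadIdx.x % warpSize;
}
```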
parent 37d8410c
......@@ -83,19 +83,21 @@ at::Tensor nhwc_bn_fwd_train(
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
nullptr,
y.DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -116,12 +118,12 @@ at::Tensor nhwc_bn_fwd_train(
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.DATA_PTR<float>());
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......@@ -161,19 +163,21 @@ at::Tensor nhwc_bn_fwd_eval(
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
nullptr,
y.DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -199,7 +203,7 @@ at::Tensor nhwc_bn_fwd_eval(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......@@ -260,19 +264,23 @@ std::vector<at::Tensor> nhwc_bn_bwd(
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
x_grad.DATA_PTR<at::Half>(),
nullptr,
dy.DATA_PTR<at::Half>());
dy.contiguous().DATA_PTR<at::Half>());
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()},
{scale_grad.DATA_PTR<float>(),
bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -293,12 +301,12 @@ std::vector<at::Tensor> nhwc_bn_bwd(
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.DATA_PTR<float>());
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......
......@@ -26,7 +26,7 @@
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include <cudnn.h>
#include "dnn.h"
#include <algorithm>
#include <vector>
......@@ -34,6 +34,7 @@
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"
#define VERBOSE_DEFAULT false
......@@ -62,8 +63,8 @@ class NhwcBatchNorm {
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
void setInputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
......@@ -77,8 +78,8 @@ class NhwcBatchNorm {
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
void setOutputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
void setOutputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
......@@ -119,13 +120,20 @@ class NhwcBatchNorm {
eps_ = eps;
}
void processCudnnStatus(const cudnnStatus_t& status,
void processCudnnStatus(const dnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
if (status != CUDNN_STATUS_SUCCESS)
#ifdef __HIP_PLATFORM_HCC__
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << miopenGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << miopenGetErrorString(status);
#else
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
#endif
}
void checkCudaStatus(const std::string& string = std::string(),
......@@ -148,8 +156,8 @@ class NhwcBatchNorm {
return retired_cta_bytes;
}
cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
dnnTensorDescriptor_t X_tensor_desc_ = nullptr;
dnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
......@@ -181,24 +189,36 @@ class NhwcBatchNorm {
std::string name_;
private:
void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
cudnnTensorFormat_t format,
cudnnDataType_t data_type,
void setTensorDescriptor(dnnTensorDescriptor_t descriptor,
dnnTensorFormat_t format,
dnnDataType_t data_type,
int n, int c, int h, int w) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w);
#else
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
#endif
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenCreateTensorDescriptor(descriptor);
#else
status = cudnnCreateTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenDestroyTensorDescriptor(descriptor);
#else
status = cudnnDestroyTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "destroy tensor_descriptor");
}
......@@ -258,6 +278,57 @@ class NhwcBatchNorm {
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
hipLaunchKernel((void *) fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#else
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
......@@ -307,6 +378,7 @@ class NhwcBatchNorm {
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1 && use_relu) {
......@@ -337,6 +409,99 @@ class NhwcBatchNorm {
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_relu_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
#else
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
......@@ -428,6 +593,7 @@ class NhwcBatchNorm {
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1 && use_relu) {
......@@ -459,7 +625,7 @@ class NhwcBatchNorm {
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
......@@ -468,7 +634,7 @@ class NhwcBatchNorm {
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
......
......@@ -85,21 +85,23 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
nullptr,
y.DATA_PTR<at::Half>(),
nullptr,
z.DATA_PTR<at::Half>(),
z.contiguous().DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -120,13 +122,13 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.DATA_PTR<float>());
workspace.push_back(bitmask.DATA_PTR<int32_t>());
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
workspace.push_back(bitmask.contiguous().DATA_PTR<bitmask_pyt_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......@@ -167,21 +169,23 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
nullptr,
y.DATA_PTR<at::Half>(),
nullptr,
z.DATA_PTR<at::Half>(),
z.contiguous().DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -208,7 +212,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......@@ -270,21 +274,24 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
x_grad.DATA_PTR<at::Half>(),
nullptr,
dy.DATA_PTR<at::Half>(),
dy.contiguous().DATA_PTR<at::Half>(),
nullptr,
z_grad.DATA_PTR<at::Half>());
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()},
{scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -305,13 +312,13 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.DATA_PTR<float>());
workspace.push_back(bitmask.DATA_PTR<int32_t>());
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
workspace.push_back(bitmask.contiguous().DATA_PTR<bitmask_pyt_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......
......@@ -26,7 +26,7 @@
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include <cudnn.h>
#include "dnn.h"
#include <algorithm>
#include <vector>
......@@ -34,7 +34,15 @@
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"
#ifdef __HIP_PLATFORM_HCC__
using bitmask_t = uint64_t;
using bitmask_pyt_t = int64_t;
#else
using bitmask_t = unsigned int;
using bitmask_pyt_t = int32_t;
#endif
#define VERBOSE_DEFAULT false
......@@ -62,8 +70,8 @@ class NhwcBatchNormAddRelu {
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
void setInputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
......@@ -77,8 +85,8 @@ class NhwcBatchNormAddRelu {
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
void setOutputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
void setOutputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
......@@ -121,13 +129,20 @@ class NhwcBatchNormAddRelu {
eps_ = eps;
}
void processCudnnStatus(const cudnnStatus_t& status,
void processCudnnStatus(const dnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
if (status != CUDNN_STATUS_SUCCESS)
#ifdef __HIP_PLATFORM_HCC__
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << miopenGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << miopenGetErrorString(status);
#else
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
#endif
}
void checkCudaStatus(const std::string& string = std::string(),
......@@ -150,8 +165,8 @@ class NhwcBatchNormAddRelu {
return retired_cta_bytes;
}
cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
dnnTensorDescriptor_t X_tensor_desc_ = nullptr;
dnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
......@@ -185,24 +200,36 @@ class NhwcBatchNormAddRelu {
std::string name_;
private:
void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
cudnnTensorFormat_t format,
cudnnDataType_t data_type,
void setTensorDescriptor(dnnTensorDescriptor_t descriptor,
dnnTensorFormat_t format,
dnnDataType_t data_type,
int n, int c, int h, int w) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w);
#else
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
#endif
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenCreateTensorDescriptor(descriptor);
#else
status = cudnnCreateTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenDestroyTensorDescriptor(descriptor);
#else
status = cudnnDestroyTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "destroy tensor_descriptor");
}
......@@ -210,7 +237,7 @@ class NhwcBatchNormAddRelu {
float *partial_sums_ = nullptr;
int *partial_counts_ = nullptr;
int *retired_ctas_ = nullptr;
unsigned int *relu_bitmask_ = nullptr;
bitmask_t *relu_bitmask_ = nullptr;
void _setFwdParams(NhwcBatchNormFwdParams *params) const;
void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
......@@ -261,6 +288,58 @@ class NhwcBatchNormAddRelu {
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
hipLaunchKernel((void *) fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#else
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
......@@ -311,6 +390,7 @@ class NhwcBatchNormAddRelu {
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1) {
......@@ -331,7 +411,56 @@ class NhwcBatchNormAddRelu {
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_add_relu_func, \
hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + \
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
#else
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
......@@ -379,6 +508,7 @@ class NhwcBatchNormAddRelu {
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1) {
......@@ -399,7 +529,7 @@ class NhwcBatchNormAddRelu {
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
......@@ -408,7 +538,7 @@ class NhwcBatchNormAddRelu {
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
......@@ -427,9 +557,13 @@ const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
const size_t num_mean_bytes = c_ * sizeof(float);
const size_t num_variance_bytes = num_mean_bytes;
#ifdef __HIP_PLATFORM_HCC__
int elems_per_group = ((m_ + 3) & ~3);
#else
int elems_per_group = ((m_ + 31) & ~31) * 2;
#endif
int group_count = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int);
const size_t bitmask_bytes = elems_per_group * group_count * sizeof(bitmask_t);
const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\
ELEMENTS_PER_LDG*2*sizeof(float);
......@@ -447,7 +581,7 @@ void NhwcBatchNormAddRelu::setWorkspacePointers(
minibatch_mean_ = static_cast<float*>(workspace[0]);
minibatch_variance_ = static_cast<float*>(workspace[1]);
relu_bitmask_ = static_cast<unsigned int*>(workspace[2]);
relu_bitmask_ = static_cast<bitmask_t*>(workspace[2]);
retired_ctas_ = static_cast<int*>(workspace[3]);
partial_sums_ = static_cast<float*>(workspace[4]);
partial_counts_ = static_cast<int*>(workspace[5]);
......
#ifdef __HIP_PLATFORM_HCC__
#include <ATen/hip/HIPContext.h>
#else
#include <ATen/cuda/CUDAContext.h>
#endif
#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H
......@@ -8,7 +12,11 @@ namespace cuda {
namespace utils {
static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
#ifdef __HIP_PLATFORM_HCC__
return getDeviceProperties(device_id)->maxSharedMemoryPerMultiProcessor;
#else
return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
#endif
}
......
#ifndef DNN_H
#define DNN_H
#ifdef __HIP_PLATFORM_HCC__
#include <miopen/miopen.h>
#define DNN_STATUS_SUCCESS miopenStatusSuccess
#define DNN_DATA_HALF miopenHalf
#define DNN_TENSOR_FORMAT 0
using dnnTensorFormat_t = int;
using dnnDataType_t = miopenDataType_t;
using dnnStatus_t = miopenStatus_t;
using dnnTensorDescriptor_t = miopenTensorDescriptor_t;
#else
#include <cudnn.h>
#define DNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS
#define DNN_DATA_HALF CUDNN_DATA_HALF
#define DNN_TENSOR_FORMAT CUDNN_TENSOR_NHWC
using dnnTensorFormat_t = cudnnTensorFormat_t;
using dnnDataType_t = cudnnDataType_t;
using dnnStatus_t = cudnnStatus_t;
using dnnTensorDescriptor_t = cudnnTensorDescriptor_t;
#endif
#endif // DNN_H
......@@ -4,6 +4,16 @@ from torch.nn.modules.batchnorm import _BatchNorm
import bnp
def check_if_rocm_pytorch():
is_rocm_pytorch = False
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
return is_rocm_pytorch
IS_ROCM_PYTORCH = check_if_rocm_pytorch()
class bn_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
......@@ -54,7 +64,11 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
if is_train:
bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
if IS_ROCM_PYTORCH:
nhw = x.shape[0] * x.shape[1] * x.shape[2]
bitmask = torch.cuda.LongTensor(((nhw + 3) & ~3) * grid_dim_y)
else:
bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
ctx.epsilon = epsilon
ctx.momentum = mom
......
import torch
import unittest
import numpy as np
import random
from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC
def generate_uniform_tensor(size, np_dtype, pyt_dtype, device):
array = None
while array is None or np.isnan(array).any():
array = np.random.uniform(low=-1.0, high=1.0, size=size).astype(np_dtype)
return torch.from_numpy(array).to(device).to(pyt_dtype)
def to_channels_last(tensor):
return tensor.permute(0, 2, 3, 1).contiguous()
def to_channels_first(tensor):
return tensor.permute(0, 3, 1, 2).contiguous()
class Bn(torch.nn.BatchNorm2d):
def __init__(self, planes, mode):
super(Bn, self).__init__(planes, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.mode = mode
def forward(self, x, z=None):
out = super().forward(x)
if self.mode == 'bn_add_relu':
out = out.add_(z)
if self.mode != 'bn':
out = out.relu_()
return out
def bn_nhwc_bwd_ref(grad_y, x, mu, ivar, gamma):
sum_dim_c = (0, 1, 2)
grad_y_f32 = grad_y.float()
x_f32 = x.float()
N = x.shape[0] * x.shape[1] * x.shape[2] # nhw
ones = torch.ones(x.shape, dtype=torch.float32, device='cuda')
xmu = x_f32 - mu
xhat = xmu * ivar
dbias = torch.sum(grad_y_f32, dim=sum_dim_c)
dscale = torch.sum(grad_y_f32 * xhat, dim=sum_dim_c)
dx1 = (gamma * ivar) / N
dx2 = (N * grad_y_f32) - (dbias * ones)
dx3 = -xhat * dscale
dx = dx1 * (dx2 + dx3)
dx = dx.half()
return dx, dscale, dbias
class TestGroupBN(unittest.TestCase):
def setUp(self, seed=5, verbose=False):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
self.verbose = verbose
def test_bn(self):
self.run_group_bn('bn')
def test_bn_relu(self):
self.run_group_bn('bn_relu')
def test_bn_add_relu(self):
self.run_group_bn('bn_add_relu')
def run_group_bn(self, mode):
if self.verbose:
print('Running {}'.format(mode))
tensor_sizes = [
(120, 64, 150, 150),
(120, 64, 75, 75),
(120, 128, 38, 38),
(120, 256, 38, 38)]
for i in range(len(tensor_sizes)):
tensor_size = tensor_sizes[i]
num_channels = tensor_size[1]
# Create input data
input_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda')
np.save('input.npy', input_data.detach().cpu().numpy())
input_data.requires_grad = True
gbn_input = torch.from_numpy(np.load('input.npy')).cuda().half()
gbn_input.requires_grad = True
residual_data = None
gbn_residual_data = None
if mode == 'bn':
fuse_relu = False
else:
fuse_relu = True
if mode == 'bn_add_relu':
residual_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda')
gbn_residual_data = to_channels_last(residual_data)
bn_grad = generate_uniform_tensor(input_data.shape, np.float16, torch.half, 'cuda')
# Create models
batchnorm_model = Bn(num_channels, mode).cuda()
group_batchnorm = BatchNorm2d_NHWC(num_channels, fuse_relu=fuse_relu, bn_group=1).cuda()
# Run reference forward
bn_output = batchnorm_model(input_data, residual_data)
# Run GBN forward
gbn_input_data = to_channels_last(gbn_input)
gbn_output = group_batchnorm(gbn_input_data, gbn_residual_data)
torch.cuda.synchronize()
# Run reference backward
# (Use the same input and parameters as GBN)
gbn_grad = to_channels_last(bn_grad)
grad = gbn_grad.clone().detach()
input_data = torch.from_numpy(np.load('input.npy')).cuda().half()
input_data = to_channels_last(input_data)
if mode != 'bn':
grad[gbn_output <= 0] = 0
bn_output_grad, _, _ = bn_nhwc_bwd_ref( \
grad,
input_data,
group_batchnorm.minibatch_mean,
group_batchnorm.minibatch_riv,
group_batchnorm.weight)
bn_output_grad = to_channels_first(bn_output_grad)
# Run GBN backward
gbn_output.backward(gbn_grad)
torch.cuda.synchronize()
gbn_output = to_channels_first(gbn_output)
gbn_output_grad = gbn_input.grad.detach().clone().cpu()
########################## Validate results ##########################
if self.verbose:
print('Validate activation')
self.validate(bn_output.shape, bn_output, gbn_output)
if self.verbose:
print('Validate grad')
self.validate(bn_output_grad.shape, bn_output_grad, gbn_output_grad, is_grad=True)
def validate(self, tensors, output_ref, output_test, is_grad=False):
output_ref = output_ref.detach().cpu().numpy()
output_test = output_test.detach().cpu().numpy()
if self.verbose:
print('>>> tensor_size\t{}'.format(tensors))
print("sum_output_ref {}, isnan {}, max {}, min {}".format(
np.sum(output_ref, dtype=float), np.isnan(output_ref).any(), np.max(output_ref), np.min(output_ref)))
print("sum_output_test {}, isnan {}, max {}, min {}".format(
np.sum(output_test, dtype=float), np.isnan(output_test).any(), np.max(output_test), np.min(output_test)))
ret = np.array_equal(output_ref, output_test)
if not ret:
ret_allclose = np.allclose(
output_ref, output_test, rtol=1e-3, atol=1e-3, equal_nan=True)
if self.verbose:
print('{}\tshape {}\tidentical {}\tclose {}'.format('cpu/gpu', tensors, ret, ret_allclose))
output_ref = output_ref.flatten()
output_test = output_test.flatten()
if not ret:
sub = np.absolute(output_ref - output_test)
norm_diff = np.average(sub)
rel = np.divide(sub, np.absolute(output_ref))
rel[rel == np.inf] = 0
max_abs_idx = np.argmax(sub)
max_rel_idx = np.argmax(rel)
if self.verbose:
print('max_diff {}, max_rel_diff {}, norm_diff {}'.format(np.max(sub), np.max(rel), np.average(sub)))
print('max_abs pair [{}] {} {}'.format(max_abs_idx, output_ref[max_abs_idx], output_test[max_abs_idx]))
print('max_rel pair [{}] {} {}'.format(max_rel_idx, output_ref[max_rel_idx], output_test[max_rel_idx]))
result = ret or ret_allclose or (is_grad and norm_diff < 1e-4)
if self.verbose:
print("Result {}".format("PASS" if result else "FAIL"))
self.assertTrue(result)
if __name__ == '__main__':
unittest.main()
......@@ -243,7 +243,7 @@ if "--bnp" in sys.argv:
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
......@@ -252,7 +252,8 @@ if "--bnp" in sys.argv:
'apex/contrib/csrc/groupbn/ipc.cu',
'apex/contrib/csrc/groupbn/interface.cpp',
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/groupbn')],
extra_compile_args={'cxx': [] + version_dependent_macros,
'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__',
......