Unverified Commit e57c84e0 authored by sarunyap, committed by GitHub

Enable group batch norm (--bnp) on ROCm (only bn_group = 1) (#51)

* Enable group batch norm (--bnp) on ROCm (only bn_group = 1)

Enable NHWC group batch norm on a single GPU on ROCm (bn_group = 1).
The multi-GPU case (bn_group > 1) will be revisited in the future.

The following are the main changes:

1) Use MIOpen data structures/functions in HIP instead of cuDNN
2) For the warp-level primitive code, ensure that the code operates
   on 64-thread-wide warps instead of 32-thread-wide warps
3) Disable all bn_group > 1 paths

Notes:

1) Multi-stream is not tested.
2) Performance has not been optimized yet.

* Fix bnp hipification

Avoid calling hipify-perl in setup.py and rely on PyTorch's internal
hipification mechanism.

* Make bnp data pointers contiguous

The contrib group batch norm implementation assumes that all input
tensors are contiguous.  When non-contiguous tensors are passed in, it
produces wrong results.  This commit explicitly calls .contiguous() on
all input tensors before accessing their data pointers, as sketched
below.
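A minimal sketch of the pattern, assuming the caller holds the
contiguous tensor in a local variable so the data pointer stays valid
for the duration of the call (bn, setInputOutputPointers, and the
tensor names mirror the wrapper call sites in the diff below and are
illustrative only):

    // Illustrative only: hand the kernel wrapper pointers into
    // contiguous NHWC memory.  .contiguous() returns the tensor itself
    // when it is already contiguous, and a contiguous copy otherwise.
    at::Tensor x_c = x.contiguous();
    at::Tensor y_c = y.contiguous();
    bn->setInputOutputPointers(x_c.data_ptr<at::Half>(),
                               nullptr,
                               y_c.data_ptr<at::Half>(),
                               nullptr);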

* Fix HIP lane id in bnp

Fix a typo in the HIP lane id computation.

* Fix ReLU bitmask for HIP in bnp

The ReLU bitmask is derived using the __ballot function, which returns
a 64-bit value in HIP.  This commit fixes the ReLU bitmask storage size
and offsets on ROCm accordingly.

This patch also fixes the kernel to set the ReLU bitmask bit to 1 when
the data is less than or equal to zero (not only strictly less than).
Not doing so can cause a stability issue.
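The corresponding type and offset selection, taken from
nhwc_batch_norm_kernel.h in the diff below (the rectified/x variables
here are illustrative placeholders for the per-lane value inside the
kernel):

    #ifdef __HIP_PLATFORM_HCC__
    using bitmask_t = uint64_t;        // __ballot returns a 64-bit wavefront mask
    #define BITMASK_OFFSET 1           // stride (in words) when indexing the bitmask buffer
    #else
    using bitmask_t = unsigned int;    // __ballot_sync returns a 32-bit warp mask
    #define BITMASK_OFFSET 2
    #endif

    // Mark the lane for rectification when the value is <= 0, not only < 0.
    bool rectified = (x <= 0.0f);
    bitmask_t relu_mask = ballot(rectified);  // ballot() wraps __ballot / __ballot_sync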

* Remove multiple of 64 offset for HIP in bnp

The multiple of 64 offset is not necessary.

* Use FP16 intermediate output to determine whether to rectify in bnp

Group batch norm takes FP16 tensors and produces FP16 output; however,
all arithmetic operations are done in FP32, so the intermediate outputs
are in FP32.  For the fusion kernels, the ReLU step inspects the FP32
intermediate output to decide whether to rectify it.  ReLU must rectify
the intermediate output if it is less than or equal to zero.  There is
a chance that the intermediate FP32 output is very close to zero, and
when it is converted to FP16, it becomes zero.  In this case, the
output is not rectified when it should be.  Since the output is not
rectified in the forward pass, the gradient is not rectified in the
backward pass.  This can cause a stability issue.

This patch can have a negative impact on the performance of group batch
norm, as we perform the FP32-to-FP16 conversion multiple times.
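A sketch of the check on the ROCm path, assuming the decision is made
on the value after a round trip through FP16 so the stored output and
the ReLU bitmask stay consistent (x stands for an FP32 intermediate
value inside the kernel; __float2half and __hle are the standard fp16
device intrinsics, provided by hip/hip_fp16.h as in the kernel diff
below):

    // Inside a __device__/__global__ function.
    const half zero_h = __float2half(0.0f);
    // Compare the value the kernel will actually store (FP16): a tiny
    // positive FP32 value that rounds to 0.0 in FP16 is treated as
    // rectified, so the forward output and backward gradient agree.
    bool rectified = __hle(__float2half(x), zero_h);
    if (rectified) {
      x = 0.0f;
    }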

* Disable dispatchX ParallelSums in HIP in bnp

dispatchX is not required for the bn_group = 1 case.

* Use traditional load/store for HIP in bnp

The built-in function has a high floating-point rounding error.  Thus,
we replace it with the traditional load/store.  Doing so breaks the
aligned-pointer property of the load/store functions, so we
conservatively use traditional load/store for all memory accesses.
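For reference, this is the shape of the ROCm store helper after the
change, copied in spirit from the kernel diff below (DEVICE_FUNCTION
and the half type come from that file; the CUDA path keeps its
inline-PTX streaming store):

    // ROCm path: element-wise half stores instead of a packed built-in.
    DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) {
      half *gmem_ = reinterpret_cast<half *>(gmem);
      half *src_  = reinterpret_cast<half *>(src);
      for (int i = 0; i < 4; ++i) {   // two packed 32-bit ints hold four halfs
        gmem_[i] = src_[i];
      }
    }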

* Replace shfl_down with shfl_sync in parallel sums for HIP in bnp

This commit separates the HIP code from the CUDA code in the parallel
sums.
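A minimal sketch of the HIP reduction this enables, using the
shfl_sync() wrapper from the kernel diff below (shfl_sync() maps to
__shfl on HIP and __shfl_sync(0xFFFFFFFFU, ...) on CUDA;
THREADS_PER_PIXEL, THREADS_PER_WARP, ELEMENTS_PER_LDG, lane_id, and x
are the kernel's constants and locals):

    // Tree reduction across a 64-lane wavefront: each step folds in the
    // partial sums held `offset` lanes above the current lane.
    for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
        x[i] += shfl_sync(x[i], offset + lane_id);
      }
    }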

* Remove -U__HIP_NO_HALF_CONVERSIONS__ for HIP in bnp

Since the built-in function is removed, -U__HIP_NO_HALF_CONVERSIONS__ is
no longer needed.

* Preserve CUDA's ReLU condition path for USE_ADD_RELU in bnp

* Add test for bnp

The test evaluates the correctness of batch norm, batch norm + ReLU,
and batch norm + add + ReLU against the reference implementation.

For the forward activation output, we validate against PyTorch's
implementation.  The group batch norm activation output must be
allclose to the PyTorch activation output for the test to pass.

For the backward gradient output, we validate against a Python
reference implementation.  Due to floating-point rounding error in the
batch norm implementation, the group batch norm gradient output might
not be allclose to the Python reference output when ReLU is used, even
though the majority of the elements are very close to each other.
Thus, instead of allclose, we use a norm-difference threshold to
determine whether the test passes (see the criterion sketched below).
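One common form of such a criterion is the relative norm difference
below; whether the test normalizes by the reference norm and the value
of the tolerance tau are details of the test itself (g_gbn is the group
batch norm gradient, g_ref the Python reference gradient):

    \[
      \frac{\lVert g_{\mathrm{gbn}} - g_{\mathrm{ref}} \rVert_2}
           {\lVert g_{\mathrm{ref}} \rVert_2} \le \tau
    \]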

* Use the warp size variable rather than hard coding the warp size in bnp

Use C10_WARP_SIZE from c10/macros/Macros.h in the host functions and use
warpSize in the device kernels instead of hard coding the warp size.
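A sketch of the split, assuming C10_WARP_SIZE expands to 64 on ROCm
builds and 32 on CUDA builds (THREADS_PER_PIXEL, THREADS_PER_CTA, and
ELEMENTS_PER_LDG are the kernel parameters used throughout the diff):

    #include <c10/macros/Macros.h>

    // Host side: size the per-CTA reduction buffer with the compile-time constant.
    int reduction_bytes =
        THREADS_PER_PIXEL * (THREADS_PER_CTA / C10_WARP_SIZE) * ELEMENTS_PER_LDG * sizeof(float);

    // Device side: derive the lane id from the built-in warpSize
    // instead of a hard-coded 32.
    __device__ int lane_id() {
      return threadIdx.x % warpSize;
    }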
parent 37d8410c
......@@ -83,19 +83,21 @@ at::Tensor nhwc_bn_fwd_train(
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
nullptr,
y.DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -116,12 +118,12 @@ at::Tensor nhwc_bn_fwd_train(
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.DATA_PTR<float>());
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......@@ -161,19 +163,21 @@ at::Tensor nhwc_bn_fwd_eval(
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
nullptr,
y.DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -199,7 +203,7 @@ at::Tensor nhwc_bn_fwd_eval(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......@@ -260,19 +264,23 @@ std::vector<at::Tensor> nhwc_bn_bwd(
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
x_grad.DATA_PTR<at::Half>(),
nullptr,
dy.DATA_PTR<at::Half>());
dy.contiguous().DATA_PTR<at::Half>());
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()},
{scale_grad.DATA_PTR<float>(),
bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -293,12 +301,12 @@ std::vector<at::Tensor> nhwc_bn_bwd(
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.DATA_PTR<float>());
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......
......@@ -26,7 +26,7 @@
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include <cudnn.h>
#include "dnn.h"
#include <algorithm>
#include <vector>
......@@ -34,6 +34,7 @@
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"
#define VERBOSE_DEFAULT false
......@@ -62,8 +63,8 @@ class NhwcBatchNorm {
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
void setInputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
......@@ -77,8 +78,8 @@ class NhwcBatchNorm {
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
void setOutputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
void setOutputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
......@@ -119,13 +120,20 @@ class NhwcBatchNorm {
eps_ = eps;
}
void processCudnnStatus(const cudnnStatus_t& status,
void processCudnnStatus(const dnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
if (status != CUDNN_STATUS_SUCCESS)
#ifdef __HIP_PLATFORM_HCC__
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << miopenGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << miopenGetErrorString(status);
#else
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
#endif
}
void checkCudaStatus(const std::string& string = std::string(),
......@@ -148,8 +156,8 @@ class NhwcBatchNorm {
return retired_cta_bytes;
}
cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
dnnTensorDescriptor_t X_tensor_desc_ = nullptr;
dnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
......@@ -181,24 +189,36 @@ class NhwcBatchNorm {
std::string name_;
private:
void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
cudnnTensorFormat_t format,
cudnnDataType_t data_type,
void setTensorDescriptor(dnnTensorDescriptor_t descriptor,
dnnTensorFormat_t format,
dnnDataType_t data_type,
int n, int c, int h, int w) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w);
#else
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
#endif
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenCreateTensorDescriptor(descriptor);
#else
status = cudnnCreateTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenDestroyTensorDescriptor(descriptor);
#else
status = cudnnDestroyTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "destroy tensor_descriptor");
}
......@@ -258,6 +278,57 @@ class NhwcBatchNorm {
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
hipLaunchKernel((void *) fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#else
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
......@@ -307,6 +378,7 @@ class NhwcBatchNorm {
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1 && use_relu) {
......@@ -337,6 +409,99 @@ class NhwcBatchNorm {
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_relu_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
#else
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
......@@ -428,6 +593,7 @@ class NhwcBatchNorm {
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1 && use_relu) {
......@@ -459,7 +625,7 @@ class NhwcBatchNorm {
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
......@@ -468,7 +634,7 @@ class NhwcBatchNorm {
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
......
......@@ -85,21 +85,23 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
nullptr,
y.DATA_PTR<at::Half>(),
nullptr,
z.DATA_PTR<at::Half>(),
z.contiguous().DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -120,13 +122,13 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.DATA_PTR<float>());
workspace.push_back(bitmask.DATA_PTR<int32_t>());
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
workspace.push_back(bitmask.contiguous().DATA_PTR<bitmask_pyt_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......@@ -167,21 +169,23 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
nullptr,
y.DATA_PTR<at::Half>(),
nullptr,
z.DATA_PTR<at::Half>(),
z.contiguous().DATA_PTR<at::Half>(),
nullptr);
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -208,7 +212,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......@@ -270,21 +274,24 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.DATA_PTR<at::Half>(),
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(),
x_grad.DATA_PTR<at::Half>(),
nullptr,
dy.DATA_PTR<at::Half>(),
dy.contiguous().DATA_PTR<at::Half>(),
nullptr,
z_grad.DATA_PTR<at::Half>());
bn->setWeightPointers({scale.DATA_PTR<float>(), bias.DATA_PTR<float>()}, {scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.DATA_PTR<float>(), running_inv_var.DATA_PTR<float>()});
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()},
{scale_grad.DATA_PTR<float>(), bias_grad.DATA_PTR<float>()});
bn->setParameterPointers({running_mean.contiguous().DATA_PTR<float>(),
running_inv_var.contiguous().DATA_PTR<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
......@@ -305,13 +312,13 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.DATA_PTR<float>());
workspace.push_back(bitmask.DATA_PTR<int32_t>());
workspace.push_back(minibatch_mean.contiguous().DATA_PTR<float>());
workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR<float>());
workspace.push_back(bitmask.contiguous().DATA_PTR<bitmask_pyt_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = ret_cta.DATA_PTR<uint8_t>();
void* retired_ctas = ret_cta.contiguous().DATA_PTR<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
......
......@@ -26,7 +26,7 @@
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include <cudnn.h>
#include "dnn.h"
#include <algorithm>
#include <vector>
......@@ -34,7 +34,15 @@
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"
#ifdef __HIP_PLATFORM_HCC__
using bitmask_t = uint64_t;
using bitmask_pyt_t = int64_t;
#else
using bitmask_t = unsigned int;
using bitmask_pyt_t = int32_t;
#endif
#define VERBOSE_DEFAULT false
......@@ -62,8 +70,8 @@ class NhwcBatchNormAddRelu {
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
void setInputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
......@@ -77,8 +85,8 @@ class NhwcBatchNormAddRelu {
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
void setOutputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
void setOutputDescriptor(const dnnTensorFormat_t format,
const dnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
......@@ -121,13 +129,20 @@ class NhwcBatchNormAddRelu {
eps_ = eps;
}
void processCudnnStatus(const cudnnStatus_t& status,
void processCudnnStatus(const dnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
if (status != CUDNN_STATUS_SUCCESS)
#ifdef __HIP_PLATFORM_HCC__
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << miopenGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << miopenGetErrorString(status);
#else
if (status != DNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
#endif
}
void checkCudaStatus(const std::string& string = std::string(),
......@@ -150,8 +165,8 @@ class NhwcBatchNormAddRelu {
return retired_cta_bytes;
}
cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
dnnTensorDescriptor_t X_tensor_desc_ = nullptr;
dnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
......@@ -185,24 +200,36 @@ class NhwcBatchNormAddRelu {
std::string name_;
private:
void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
cudnnTensorFormat_t format,
cudnnDataType_t data_type,
void setTensorDescriptor(dnnTensorDescriptor_t descriptor,
dnnTensorFormat_t format,
dnnDataType_t data_type,
int n, int c, int h, int w) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w);
#else
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
#endif
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenCreateTensorDescriptor(descriptor);
#else
status = cudnnCreateTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) {
dnnStatus_t status = DNN_STATUS_SUCCESS;
#ifdef __HIP_PLATFORM_HCC__
status = miopenDestroyTensorDescriptor(descriptor);
#else
status = cudnnDestroyTensorDescriptor(descriptor);
#endif
processCudnnStatus(status, "destroy tensor_descriptor");
}
......@@ -210,7 +237,7 @@ class NhwcBatchNormAddRelu {
float *partial_sums_ = nullptr;
int *partial_counts_ = nullptr;
int *retired_ctas_ = nullptr;
unsigned int *relu_bitmask_ = nullptr;
bitmask_t *relu_bitmask_ = nullptr;
void _setFwdParams(NhwcBatchNormFwdParams *params) const;
void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
......@@ -261,6 +288,58 @@ class NhwcBatchNormAddRelu {
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
hipLaunchKernel((void *) fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#else
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
......@@ -311,6 +390,7 @@ class NhwcBatchNormAddRelu {
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1) {
......@@ -331,7 +411,56 @@ class NhwcBatchNormAddRelu {
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#ifdef __HIP_PLATFORM_HCC__
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
hipFuncSetAttribute((void *) bwd_add_relu_func, \
hipFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + \
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
if (COOP) { \
hipLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
hipLaunchKernel((void *) bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
#else
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
......@@ -379,6 +508,7 @@ class NhwcBatchNormAddRelu {
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
#endif
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
if (outer_loops == 1) {
......@@ -399,7 +529,7 @@ class NhwcBatchNormAddRelu {
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
......@@ -408,7 +538,7 @@ class NhwcBatchNormAddRelu {
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
......@@ -427,9 +557,13 @@ const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
const size_t num_mean_bytes = c_ * sizeof(float);
const size_t num_variance_bytes = num_mean_bytes;
#ifdef __HIP_PLATFORM_HCC__
int elems_per_group = ((m_ + 3) & ~3);
#else
int elems_per_group = ((m_ + 31) & ~31) * 2;
#endif
int group_count = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int);
const size_t bitmask_bytes = elems_per_group * group_count * sizeof(bitmask_t);
const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\
ELEMENTS_PER_LDG*2*sizeof(float);
......@@ -447,7 +581,7 @@ void NhwcBatchNormAddRelu::setWorkspacePointers(
minibatch_mean_ = static_cast<float*>(workspace[0]);
minibatch_variance_ = static_cast<float*>(workspace[1]);
relu_bitmask_ = static_cast<unsigned int*>(workspace[2]);
relu_bitmask_ = static_cast<bitmask_t*>(workspace[2]);
retired_ctas_ = static_cast<int*>(workspace[3]);
partial_sums_ = static_cast<float*>(workspace[4]);
partial_counts_ = static_cast<int*>(workspace[5]);
......
#ifdef __HIP_PLATFORM_HCC__
#include <ATen/hip/HIPContext.h>
#else
#include <ATen/cuda/CUDAContext.h>
#endif
#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H
......@@ -8,7 +12,11 @@ namespace cuda {
namespace utils {
static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
#ifdef __HIP_PLATFORM_HCC__
return getDeviceProperties(device_id)->maxSharedMemoryPerMultiProcessor;
#else
return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
#endif
}
......
#ifndef DNN_H
#define DNN_H
#ifdef __HIP_PLATFORM_HCC__
#include <miopen/miopen.h>
#define DNN_STATUS_SUCCESS miopenStatusSuccess
#define DNN_DATA_HALF miopenHalf
#define DNN_TENSOR_FORMAT 0
using dnnTensorFormat_t = int;
using dnnDataType_t = miopenDataType_t;
using dnnStatus_t = miopenStatus_t;
using dnnTensorDescriptor_t = miopenTensorDescriptor_t;
#else
#include <cudnn.h>
#define DNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS
#define DNN_DATA_HALF CUDNN_DATA_HALF
#define DNN_TENSOR_FORMAT CUDNN_TENSOR_NHWC
using dnnTensorFormat_t = cudnnTensorFormat_t;
using dnnDataType_t = cudnnDataType_t;
using dnnStatus_t = cudnnStatus_t;
using dnnTensorDescriptor_t = cudnnTensorDescriptor_t;
#endif
#endif // DNN_H
......@@ -26,9 +26,24 @@
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#endif
#include <stdint.h>
#include <algorithm>
#ifdef __HIP_PLATFORM_HCC__
using bitmask_t = uint64_t;
#define BITMASK_OFFSET 1
#define ONE_BITMASK 1UL
#else
using bitmask_t = unsigned int;
#define BITMASK_OFFSET 2
#define ONE_BITMASK 1U
#endif
#define DEVICE_FUNCTION static inline __device__
// CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.
......@@ -37,6 +52,37 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void syncwarp() {
#ifdef __HIP_PLATFORM_HCC__
__builtin_amdgcn_wave_barrier();
#else
__syncwarp();
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
DEVICE_FUNCTION T shfl_sync(T var, int src_lane) {
#ifdef __HIP_PLATFORM_HCC__
return __shfl(var, src_lane);
#else
return __shfl_sync(0xFFFFFFFFU, var, src_lane);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION bitmask_t ballot(int predicate) {
#ifdef __HIP_PLATFORM_HCC__
return __ballot(predicate);
#else
return __ballot_sync(0xFFFFFFFFU, predicate);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T, int ELEMENTS_PER_LDG >
struct PackedStorage {
enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };
......@@ -55,12 +101,20 @@ struct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {
template< int N >
DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) {
// Convert from two f32s to two f16s (mantissa LSB rounds to nearest even)
// (From 64-bit to 32-bit)
half *dst_ = (half *) dst;
#pragma unroll
for (int i = 0; i < N; ++i) {
#ifdef __HIP_PLATFORM_HCC__
dst_[2*i] = __float2half(src[2*i]);
dst_[2*i+1] = __float2half(src[2*i+1]);
#else
uint16_t lo, hi;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2*i+0]));
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2*i+1]));
asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi));
#endif
}
}
......@@ -78,12 +132,19 @@ DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {
template< int N >
DEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) {
// Convert from two f16s to two f32s (From 32-bit to 64-bit)
#pragma unroll
for (int i = 0; i < N; ++i) {
#ifdef __HIP_PLATFORM_HCC__
half *src_ = (half *) src;
dst[2*i] = __half2float(src_[2*i]);
dst[2*i+1] = __half2float(src_[2*i+1]);
#else
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i]));
asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+0]) : "h"(lo));
asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+1]) : "h"(hi));
#endif
}
}
......@@ -106,9 +167,13 @@ DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) {
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) {
#ifdef __HIP_PLATFORM_HCC__
dst[0] = __ldg((const int*) gmem);
#else
unsigned int tmp;
asm volatile ("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l" ((const uint *)gmem));
dst[0] = tmp;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -122,11 +187,17 @@ DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) {
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) {
#ifdef __HIP_PLATFORM_HCC__
int2 tmp = __ldg((const int2*) gmem);
dst[0] = tmp.x;
dst[1] = tmp.y;
#else
int2 tmp;
asm volatile ("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];"
: "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2 *)gmem));
dst[0] = tmp.x;
dst[1] = tmp.y;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -156,22 +227,42 @@ DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) {
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) {
#ifdef __HIP_PLATFORM_HCC__
reinterpret_cast<int*>(gmem)[0] = src[0];
#else
unsigned int tmp = src[0];
asm volatile ("st.global.cs.s32 [%0], %1;"
:: "l"((uint *)gmem) , "r"(tmp));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) {
#ifdef __HIP_PLATFORM_HCC__
half *gmem_ = (half *) gmem;
half *src_ = (half *) src;
for (int i = 0; i < 4; i++) {
gmem_[i] = src_[i];
}
#else
reinterpret_cast<int2*>(gmem)[0] = make_int2(src[0], src[1]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) {
#ifdef __HIP_PLATFORM_HCC__
half *gmem_ = (half *) gmem;
half *src_ = (half *) src;
for (int i = 0; i < 4; i++) {
gmem_[i] = src_[i];
}
#else
asm volatile ("st.global.cs.v2.s32 [%0], {%1,%2};"
:: "l"((uint *)gmem) , "r"(src[0]), "r"( src[1]));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -194,28 +285,65 @@ DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef __HIP_PLATFORM_HCC__
DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[4]) {
half *gmem_ = (half *) gmem;
gmem_[0] = __float2half(src[0]);
gmem_[1] = __float2half(src[1]);
gmem_[2] = __float2half(src[2]);
gmem_[3] = __float2half(src[3]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[4]) {
half *gmem_ = (half *) gmem;
gmem_[0] = __float2half(src[0]);
gmem_[1] = __float2half(src[1]);
gmem_[2] = __float2half(src[2]);
gmem_[3] = __float2half(src[3]);
}
#endif
DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
dst[0] = gmem[2*idx];
dst[1] = gmem[2*idx+1];
#else
float2 tmp = __ldg(reinterpret_cast<const float2*>(&gmem[2*idx]));
dst[0] = tmp.x;
dst[1] = tmp.y;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
dst[0] = gmem[4*idx];
dst[1] = gmem[4*idx+1];
dst[2] = gmem[4*idx+2];
dst[3] = gmem[4*idx+3];
#else
float4 tmp = __ldg(reinterpret_cast<const float4*>(&gmem[4*idx]));
dst[0] = tmp.x;
dst[1] = tmp.y;
dst[2] = tmp.z;
dst[3] = tmp.w;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
x[0] = smem[2*idx];
x[1] = smem[2*idx+1];
#else
float2 tmp = *(const float2*) &smem[2*idx];
x[0] = tmp.x;
x[1] = tmp.y;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -227,43 +355,79 @@ DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) {
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
x[0] = smem[4*idx];
x[1] = smem[4*idx+1];
x[2] = smem[4*idx+2];
x[3] = smem[4*idx+3];
#else
float4 tmp = *(const float4*) &smem[4*idx];
x[0] = tmp.x;
x[1] = tmp.y;
x[2] = tmp.z;
x[3] = tmp.w;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) {
#ifdef __HIP_PLATFORM_HCC__
x[0] = smem[2*idx];
x[1] = smem[2*idx+1];
#else
int2 tmp = *(const int2*) &smem[2*idx];
x[0] = tmp.x;
x[1] = tmp.y;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) {
#ifdef __HIP_PLATFORM_HCC__
gmem[2*idx] = src[0];
gmem[2*idx+1] = src[1];
#else
reinterpret_cast<float2*>(&gmem[2*idx])[0] = make_float2(src[0], src[1]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) {
#ifdef __HIP_PLATFORM_HCC__
gmem[4*idx] = src[0];
gmem[4*idx+1] = src[1];
gmem[4*idx+2] = src[2];
gmem[4*idx+3] = src[3];
#else
reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], src[3]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) {
#ifdef __HIP_PLATFORM_HCC__
gmem[4*idx] = src[0]*coeff;
gmem[4*idx+1] = src[1]*coeff;
gmem[4*idx+2] = src[2]*coeff;
gmem[4*idx+3] = src[3]*coeff;
#else
reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) {
#ifdef __HIP_PLATFORM_HCC__
smem[2*idx] = x[0];
smem[2*idx+1] = x[1];
#else
reinterpret_cast<float2*>(&smem[2*idx])[0] = make_float2(x[0], x[1]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -275,13 +439,25 @@ DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) {
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) {
#ifdef __HIP_PLATFORM_HCC__
smem[4*idx] = x[0];
smem[4*idx+1] = x[1];
smem[4*idx+2] = x[2];
smem[4*idx+3] = x[3];
#else
reinterpret_cast<float4*>(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) {
#ifdef __HIP_PLATFORM_HCC__
smem[2*idx] = x[0];
smem[2*idx+1] = x[1];
#else
reinterpret_cast<int2*>(&smem[2*idx])[0] = make_int2(x[0], x[1]);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -370,7 +546,11 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
const int magic,
const int sync_iters) {
// The size of a warp.
#ifdef __HIP_PLATFORM_HCC__
const int THREADS_PER_WARP = 64;
#else
const int THREADS_PER_WARP = 32;
#endif
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of threads per pixel.
......@@ -388,10 +568,19 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
// total size of data per sync iter
const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;
#ifdef __HIP_PLATFORM_HCC__
for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
}
#else
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
}
#endif
// The warp leaders, write to SMEM.
if (lane_id < THREADS_PER_PIXEL) {
......@@ -416,17 +605,25 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
add(x, y);
}
#ifdef __HIP_PLATFORM_HCC__
for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) {
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
}
#else
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
}
#endif
// Make sure the data was read from SMEM.
__syncwarp();
syncwarp();
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
// probably could do it earlier, before sync
#ifndef __HIP_PLATFORM_HCC__ // bn_group > 1 is not enabled on HIP
for (int sync_iter=0; sync_iter < sync_iters; ++sync_iter) {
//float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];
void* params_pair_data = params_pair_datas[sync_iter];
......@@ -469,6 +666,7 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
add(x, other);
}
#endif
// finally, after syncing up and accounting for partial sums from
// other GPUs as required, write the result
......@@ -483,7 +681,11 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
// The size of a warp.
#ifdef __HIP_PLATFORM_HCC__
const int THREADS_PER_WARP = 64;
#else
const int THREADS_PER_WARP = 32;
#endif
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of threads per pixel.
......@@ -496,8 +698,8 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL*2+lane_id);
}
// The warp leaders, write to SMEM.
......@@ -524,12 +726,12 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
}
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
x[i] += shfl_sync(x[i], THREADS_PER_PIXEL*2+lane_id);
}
// Make sure the data was read from SMEM.
__syncwarp();
syncwarp();
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
......@@ -543,7 +745,7 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
// The size of a warp.
const int THREADS_PER_WARP = 32;
const int THREADS_PER_WARP = warpSize;
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of pixels computed by a single warp.
......@@ -560,7 +762,7 @@ DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], in
// Compute the parallel sums.
for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) {
// NOP.
__syncwarp();
syncwarp();
// Read the running sum from the other thread.
float y[ELEMENTS_PER_LDG];
......@@ -572,7 +774,7 @@ DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], in
add(x, y);
// NOP.
__syncwarp();
syncwarp();
// Update the sum in SMEM.
if (offset > 1 && nhw_in_warp < offset) {
......@@ -600,7 +802,7 @@ DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], in
// We have the running mean and running m2. Let's build the mean/var of the CTA.
for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) {
// NOP.
__syncwarp();
syncwarp();
// Read the mean and variance from the other pixel.
float y[ELEMENTS_PER_LDG];
......@@ -612,7 +814,7 @@ DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], in
add(x, y);
// NOP.
__syncwarp();
syncwarp();
// Store the mean/var for the different pixels.
if (nhw < offset) {
......@@ -684,8 +886,12 @@ DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count
int retired_ctas = -1;
do {
__threadfence();
#ifdef __HIP_PLATFORM_HCC__
retired_ctas = __ldg((const int*) gmem_retired_ctas);
#else
asm volatile ("ld.global.cg.b32 %0, [%1];"
: "=r"(retired_ctas) : "l"(gmem_retired_ctas));
#endif
} while (retired_ctas != 0);
}
__syncthreads();
......@@ -806,7 +1012,7 @@ struct NhwcBatchNormFwdParams {
// saved mean/var (refer BN API from cudnn doc)
float *gmem_saved_mean, *gmem_saved_var;
// ReLU bitmask
unsigned int *gmem_relu_bitmask;
bitmask_t *gmem_relu_bitmask;
// The dimensions.
int nhw, c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
......@@ -861,7 +1067,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
......@@ -878,6 +1084,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
#ifdef __HIP_PLATFORM_HCC__
const half zero_h = __float2half(0.0F);
#endif
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
......@@ -960,11 +1170,15 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
zero_array(x_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
#ifndef __HIP_PLATFORM_HCC__
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
} else {
#endif
ldg(x_storage[i], &gmem_src[idx*params.c]);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
is_valid[i] = 1.f;
}
}
......@@ -1089,7 +1303,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
// Run the parallel sum accross the CTA to get the local sum.
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, m1, thread_in_cta_nhw);
__syncthreads();
......@@ -1106,7 +1324,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
// Run the parallel sum accross the CTA to get the local adjusted variance.
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, m2, thread_in_cta_nhw);
// The workspace in global memory is distributed across the different CTA.
@@ -1152,14 +1374,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
add(m1, tmp);
}
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0)
{
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, m1, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -1209,14 +1439,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
}
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0)
{
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, m2, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
read_from_smem(m2, smem, thread_in_cta_c);
@@ -1263,8 +1501,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +
bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask +
#ifdef __HIP_PLATFORM_HCC__
((params.nhw + 3) & ~3) * c_blk_index;
#else
((params.nhw + 31) & ~31) * 2 * c_blk_index;
#endif
// Store the elements in registers.
#pragma unroll 1
@@ -1289,23 +1531,31 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
float x1_math[ELEMENTS_PER_LDG];
ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
add(x_math, x1_math);
unsigned int relu_mask;
bitmask_t relu_mask;
#ifdef __HIP_PLATFORM_HCC__
int lane_id = threadIdx.x & 63;
#else
int lane_id = threadIdx.x & 31;
#endif
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
bool rectified = x_math[i] < 0.0F;
unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);
if (lane_id == i) {
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
#ifdef __HIP_PLATFORM_HCC__
bool rectified = __hle(__float2half(x_math[j]), zero_h);
#else
bool rectified = x_math[j] < 0;
#endif
bitmask_t local_relu_mask = ballot(rectified);
if (lane_id == j) {
// Thread 0 remembers the relu_mask from the first time through this
// loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.
relu_mask = local_relu_mask;
}
if (rectified) {
x_math[i] = 0.0F;
x_math[j] = 0.0F;
}
}
if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;
gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id] = relu_mask;
}
} else if (USE_RELU) {
relu_activation(x_math);
@@ -1352,21 +1602,29 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
float x1_math[ELEMENTS_PER_LDG];
ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
add(x_math, x1_math);
unsigned int relu_mask;
bitmask_t relu_mask;
#ifdef __HIP_PLATFORM_HCC__
int lane_id = threadIdx.x & 63;
#else
int lane_id = threadIdx.x & 31;
#endif
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
bool rectified = x_math[i] < 0.0F;
unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);
if (lane_id == i) {
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
#ifdef __HIP_PLATFORM_HCC__
bool rectified = __hle(__float2half(x_math[j]), zero_h);
#else
bool rectified = x_math[j] < 0;
#endif
bitmask_t local_relu_mask = ballot(rectified);
if (lane_id == j) {
relu_mask = local_relu_mask;
}
if (rectified) {
x_math[i] = 0.0F;
x_math[j] = 0.0F;
}
}
if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;
gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id] = relu_mask;
}
} else if (USE_RELU) {
relu_activation(x_math);
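As an illustration of the rectification test used on the HIP path above (__hle on the FP16-converted intermediate value, versus the FP32 x < 0 test kept on the CUDA path), here is a minimal NumPy sketch; the 1e-8 input is an arbitrary example of a positive FP32 value that underflows to zero in FP16:

    import numpy as np

    # A tiny positive FP32 intermediate value that rounds to 0.0 in FP16
    # (the smallest positive FP16 subnormal is about 6e-8).
    x_f32 = np.float32(1.0e-8)
    x_f16 = np.float16(x_f32)                    # the value actually stored: 0.0

    cuda_rectified = x_f32 < np.float32(0.0)     # False: FP32 test does not rectify
    hip_rectified = x_f16 <= np.float16(0.0)     # True: FP16 test with <= does rectify

    print(x_f16, cuda_rectified, hip_rectified)  # 0.0 False True

Since the FP16 value written to memory is 0.0 either way, the <= test keeps the stored activation consistent with the bit recorded in the ReLU bitmask.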
@@ -1395,7 +1653,7 @@ struct NhwcBatchNormBwdParams {
// The mean/inv-var saved from fwd pass
float *gmem_saved_mean, *gmem_saved_var;
// ReLU bitmask
unsigned int *gmem_relu_bitmask;
bitmask_t *gmem_relu_bitmask;
// The dimensions.
int nhw, c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
@@ -1536,7 +1794,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
@@ -1691,7 +1949,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
// dscale parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -1699,7 +1961,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
__syncthreads();
// dbias parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -1740,13 +2006,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
// dscale parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -1754,13 +2028,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
__syncthreads();
// dbias parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -1900,7 +2182,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
@@ -2081,7 +2363,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
// dscale parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -2089,7 +2375,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
__syncthreads();
// dbias parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -2130,13 +2420,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
// dscale parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -2144,13 +2442,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
__syncthreads();
// dbias parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -2288,7 +2594,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
@@ -2353,8 +2659,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
cta_nhw_smem -= offset;
}
const unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +
const bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask +
#ifdef __HIP_PLATFORM_HCC__
((params.nhw + 3) & ~3) * c_blk_index;
#else
((params.nhw + 31) & ~31) * 2 * c_blk_index;
#endif
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
@@ -2363,11 +2673,15 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
#ifdef __HIP_PLATFORM_HCC__
int lane_id = threadIdx.x & 63;
#else
int lane_id = threadIdx.x & 31;
#endif
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS];
bitmask_t relu_mask[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
@@ -2389,7 +2703,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
if (lane_id < ELEMENTS_PER_LDG) {
relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id];
relu_mask[i] = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id];
}
}
}
@@ -2403,8 +2717,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
bool rectified[ELEMENTS_PER_LDG];
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) &
(1U << lane_id)) != 0);
rectified[j] = ((shfl_sync(relu_mask[i], j) &
(ONE_BITMASK << lane_id)) != 0);
}
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
@@ -2444,8 +2758,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
unsigned int relu_mask;
bitmask_t relu_mask;
#ifdef __HIP_PLATFORM_HCC__
int lane_id = threadIdx.x & 63;
#else
int lane_id = threadIdx.x & 31;
#endif
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid_nhw) {
@@ -2454,14 +2772,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
if (lane_id < ELEMENTS_PER_LDG) {
relu_mask = gmem_relu_bitmask[idx * 2 + lane_id];
relu_mask = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id];
}
}
bool rectified[ELEMENTS_PER_LDG];
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) &
(1U << lane_id)) != 0);
rectified[j] = ((shfl_sync(relu_mask, j) &
(ONE_BITMASK << lane_id)) != 0);
}
// The offset to store in SMEM.
@@ -2499,7 +2817,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
// dscale parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -2507,7 +2829,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
__syncthreads();
// dbias parallel sum
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -2548,13 +2874,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
}
// dscale parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dscale, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
@@ -2562,13 +2896,21 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
__syncthreads();
// dbias parallel sum
#ifndef __HIP_PLATFORM_HCC__
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
#endif
#ifdef __HIP_PLATFORM_HCC__
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::template dispatch<THREADS_PER_CTA>(
#else
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
#endif
smem, dbias, thread_in_cta_nhw);
#ifndef __HIP_PLATFORM_HCC__
}
#endif
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
......
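The forward kernels above record one ballot word per loaded element (lane j keeps the word for element j and stores it at idx * BITMASK_OFFSET + lane_id), and the backward kernel recovers each lane's bit with shfl_sync and ONE_BITMASK << lane_id. A rough Python emulation of that packing and unpacking, assuming ELEMENTS_PER_LDG is 4 and contrasting a 32-lane CUDA warp with a 64-lane ROCm wavefront (the wider wavefront is why the mask word grows to the 64-bit bitmask_t):

    import random

    ELEMENTS_PER_LDG = 4  # assumed value for the sketch

    def ballot(flags):
        # Emulate __ballot: bit l of the result is set when lane l's flag is true.
        mask = 0
        for lane, flag in enumerate(flags):
            if flag:
                mask |= 1 << lane
        return mask

    def pack(rectified, warp_size):
        # Forward: the ballot word for element j (kept by lane j in the kernel).
        return [ballot([rectified[lane][j] for lane in range(warp_size)])
                for j in range(ELEMENTS_PER_LDG)]

    def was_rectified(masks, lane, j):
        # Backward: broadcast element j's word, then test this lane's bit.
        return bool((masks[j] >> lane) & 1)

    for warp_size in (32, 64):  # CUDA warp vs. ROCm wavefront
        rectified = [[random.random() < 0.5 for _ in range(ELEMENTS_PER_LDG)]
                     for _ in range(warp_size)]
        masks = pack(rectified, warp_size)
        assert all(was_rectified(masks, lane, j) == rectified[lane][j]
                   for lane in range(warp_size)
                   for j in range(ELEMENTS_PER_LDG))
        # A 64-lane ballot needs 64-bit words; 32 lanes fit in an unsigned int.
        assert all(m < (1 << warp_size) for m in masks)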
@@ -4,6 +4,16 @@ from torch.nn.modules.batchnorm import _BatchNorm
import bnp
def check_if_rocm_pytorch():
is_rocm_pytorch = False
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
return is_rocm_pytorch
IS_ROCM_PYTORCH = check_if_rocm_pytorch()
class bn_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
@@ -54,7 +64,11 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
if is_train:
bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
if IS_ROCM_PYTORCH:
nhw = x.shape[0] * x.shape[1] * x.shape[2]
bitmask = torch.cuda.LongTensor(((nhw + 3) & ~3) * grid_dim_y)
else:
bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
ctx.epsilon = epsilon
ctx.momentum = mom
......
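For a sense of scale, a small sketch mirroring the element counts the two allocation branches above produce; the shape is the first size used by the test below, converted to channels-last, and grid_dim_y is assumed to be 1 purely for the arithmetic:

    def gbn_bitmask_numel(nhwc_shape, grid_dim_y, is_rocm):
        # Mirrors the two branches in bn_addrelu_NHWC_impl.forward above.
        n, h, w, c = nhwc_shape
        if is_rocm:
            nhw = n * h * w
            return ((nhw + 3) & ~3) * grid_dim_y              # 64-bit words
        return ((n * h * w * c + 31) // 32) * 2 * grid_dim_y  # 32-bit words

    shape = (120, 150, 150, 64)  # (120, 64, 150, 150) in NHWC order
    print(gbn_bitmask_numel(shape, 1, is_rocm=False))  # 10800000 int32 elements
    print(gbn_bitmask_numel(shape, 1, is_rocm=True))   # 2700000 int64 elements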
import torch
import unittest
import numpy as np
import random
from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC
def generate_uniform_tensor(size, np_dtype, pyt_dtype, device):
array = None
while array is None or np.isnan(array).any():
array = np.random.uniform(low=-1.0, high=1.0, size=size).astype(np_dtype)
return torch.from_numpy(array).to(device).to(pyt_dtype)
def to_channels_last(tensor):
return tensor.permute(0, 2, 3, 1).contiguous()
def to_channels_first(tensor):
return tensor.permute(0, 3, 1, 2).contiguous()
class Bn(torch.nn.BatchNorm2d):
def __init__(self, planes, mode):
super(Bn, self).__init__(planes, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.mode = mode
def forward(self, x, z=None):
out = super().forward(x)
if self.mode == 'bn_add_relu':
out = out.add_(z)
if self.mode != 'bn':
out = out.relu_()
return out
def bn_nhwc_bwd_ref(grad_y, x, mu, ivar, gamma):
sum_dim_c = (0, 1, 2)
grad_y_f32 = grad_y.float()
x_f32 = x.float()
N = x.shape[0] * x.shape[1] * x.shape[2] # nhw
ones = torch.ones(x.shape, dtype=torch.float32, device='cuda')
xmu = x_f32 - mu
xhat = xmu * ivar
dbias = torch.sum(grad_y_f32, dim=sum_dim_c)
dscale = torch.sum(grad_y_f32 * xhat, dim=sum_dim_c)
dx1 = (gamma * ivar) / N
dx2 = (N * grad_y_f32) - (dbias * ones)
dx3 = -xhat * dscale
dx = dx1 * (dx2 + dx3)
dx = dx.half()
return dx, dscale, dbias
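Written out, bn_nhwc_bwd_ref computes the usual batch-norm backward formulas, with ivar the saved inverse standard deviation, N = n * h * w, and the sums taken over the N, H, W dimensions:

    \hat{x}_i = (x_i - \mu)\,\mathrm{ivar}, \qquad
    d\beta = \sum_{i=1}^{N} dy_i, \qquad
    d\gamma = \sum_{i=1}^{N} dy_i\,\hat{x}_i,

    dx_i = \frac{\gamma\,\mathrm{ivar}}{N}\Bigl(N\,dy_i - d\beta - \hat{x}_i\,d\gamma\Bigr),

where dscale and dbias in the code are dgamma and dbeta, and dx is cast back to FP16.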
class TestGroupBN(unittest.TestCase):
def setUp(self, seed=5, verbose=False):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
self.verbose = verbose
def test_bn(self):
self.run_group_bn('bn')
def test_bn_relu(self):
self.run_group_bn('bn_relu')
def test_bn_add_relu(self):
self.run_group_bn('bn_add_relu')
def run_group_bn(self, mode):
if self.verbose:
print('Running {}'.format(mode))
tensor_sizes = [
(120, 64, 150, 150),
(120, 64, 75, 75),
(120, 128, 38, 38),
(120, 256, 38, 38)]
for i in range(len(tensor_sizes)):
tensor_size = tensor_sizes[i]
num_channels = tensor_size[1]
# Create input data
input_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda')
np.save('input.npy', input_data.detach().cpu().numpy())
input_data.requires_grad = True
gbn_input = torch.from_numpy(np.load('input.npy')).cuda().half()
gbn_input.requires_grad = True
residual_data = None
gbn_residual_data = None
if mode == 'bn':
fuse_relu = False
else:
fuse_relu = True
if mode == 'bn_add_relu':
residual_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda')
gbn_residual_data = to_channels_last(residual_data)
bn_grad = generate_uniform_tensor(input_data.shape, np.float16, torch.half, 'cuda')
# Create models
batchnorm_model = Bn(num_channels, mode).cuda()
group_batchnorm = BatchNorm2d_NHWC(num_channels, fuse_relu=fuse_relu, bn_group=1).cuda()
# Run reference forward
bn_output = batchnorm_model(input_data, residual_data)
# Run GBN forward
gbn_input_data = to_channels_last(gbn_input)
gbn_output = group_batchnorm(gbn_input_data, gbn_residual_data)
torch.cuda.synchronize()
# Run reference backward
# (Use the same input and parameters as GBN)
gbn_grad = to_channels_last(bn_grad)
grad = gbn_grad.clone().detach()
input_data = torch.from_numpy(np.load('input.npy')).cuda().half()
input_data = to_channels_last(input_data)
if mode != 'bn':
grad[gbn_output <= 0] = 0
bn_output_grad, _, _ = bn_nhwc_bwd_ref( \
grad,
input_data,
group_batchnorm.minibatch_mean,
group_batchnorm.minibatch_riv,
group_batchnorm.weight)
bn_output_grad = to_channels_first(bn_output_grad)
# Run GBN backward
gbn_output.backward(gbn_grad)
torch.cuda.synchronize()
gbn_output = to_channels_first(gbn_output)
gbn_output_grad = gbn_input.grad.detach().clone().cpu()
########################## Validate results ##########################
if self.verbose:
print('Validate activation')
self.validate(bn_output.shape, bn_output, gbn_output)
if self.verbose:
print('Validate grad')
self.validate(bn_output_grad.shape, bn_output_grad, gbn_output_grad, is_grad=True)
def validate(self, tensors, output_ref, output_test, is_grad=False):
output_ref = output_ref.detach().cpu().numpy()
output_test = output_test.detach().cpu().numpy()
if self.verbose:
print('>>> tensor_size\t{}'.format(tensors))
print("sum_output_ref {}, isnan {}, max {}, min {}".format(
np.sum(output_ref, dtype=float), np.isnan(output_ref).any(), np.max(output_ref), np.min(output_ref)))
print("sum_output_test {}, isnan {}, max {}, min {}".format(
np.sum(output_test, dtype=float), np.isnan(output_test).any(), np.max(output_test), np.min(output_test)))
ret = np.array_equal(output_ref, output_test)
if not ret:
ret_allclose = np.allclose(
output_ref, output_test, rtol=1e-3, atol=1e-3, equal_nan=True)
if self.verbose:
print('{}\tshape {}\tidentical {}\tclose {}'.format('cpu/gpu', tensors, ret, ret_allclose))
output_ref = output_ref.flatten()
output_test = output_test.flatten()
if not ret:
sub = np.absolute(output_ref - output_test)
norm_diff = np.average(sub)
rel = np.divide(sub, np.absolute(output_ref))
rel[rel == np.inf] = 0
max_abs_idx = np.argmax(sub)
max_rel_idx = np.argmax(rel)
if self.verbose:
print('max_diff {}, max_rel_diff {}, norm_diff {}'.format(np.max(sub), np.max(rel), np.average(sub)))
print('max_abs pair [{}] {} {}'.format(max_abs_idx, output_ref[max_abs_idx], output_test[max_abs_idx]))
print('max_rel pair [{}] {} {}'.format(max_rel_idx, output_ref[max_rel_idx], output_test[max_rel_idx]))
result = ret or ret_allclose or (is_grad and norm_diff < 1e-4)
if self.verbose:
print("Result {}".format("PASS" if result else "FAIL"))
self.assertTrue(result)
if __name__ == '__main__':
unittest.main()
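A minimal usage sketch of the module under test, following how run_group_bn drives it; the tensor size here is arbitrary, and as in the test the input must be a contiguous channels-last FP16 CUDA tensor with bn_group = 1:

    import torch
    from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC

    bn = BatchNorm2d_NHWC(64, fuse_relu=True, bn_group=1).cuda()
    x = torch.randn(8, 64, 38, 38, device='cuda', dtype=torch.half)
    x_nhwc = x.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
    y = bn(x_nhwc, None)                         # second argument is the residual,
                                                 # used only by the add+ReLU fusion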
@@ -243,7 +243,7 @@ if "--bnp" in sys.argv:
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
@@ -252,7 +252,8 @@ if "--bnp" in sys.argv:
'apex/contrib/csrc/groupbn/ipc.cu',
'apex/contrib/csrc/groupbn/interface.cpp',
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
include_dirs=[os.path.join(this_dir, 'csrc'),
os.path.join(this_dir, 'apex/contrib/csrc/groupbn')],
extra_compile_args={'cxx': [] + version_dependent_macros,
'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__',
......
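Since the guard above looks for the literal flag in sys.argv, the group batch norm extension is built by adding --bnp to the setup invocation (for example, python setup.py install --bnp); with a ROCm build of PyTorch the nvcc check is skipped, as the changed condition shows.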