Commit fedfe0d7 authored by jjsjann123, committed by mcarilli

Bnp integration pr (#275)

* Persistent group batchnorm added

Added persistent grouped batch norm for performance runs on strong-scaling cases;
currently it only supports:

  1. nhwc layout
  2. fp16
  3. synchronization only within a node!

An environment variable is used to tune LAUNCH_MARGIN, which limits how many CTAs the
persistent kernel may occupy (a sketch of the resulting CTA-budget arithmetic follows the
change list below).

Documentation and examples will follow.

* updating type().scalarType() to scalar_type()

* moving launch margin to be defined at layer creation; adding a knob to cap max CTAs per SM

* fixing the cta computation

* review comment:

set device_id through cudaGetDevice()
move cudaMemset to cudaMemsetAsync
updated __threadfence() to __threadfence_system() for inter-device writes
parent e7beba17
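
Below is a minimal, standalone C++ sketch (not part of the PR) of how the two knobs above bound
the persistent kernel's grid, mirroring smem_driven_fwd_occupancy()/max_fwd_grid_x() in the
headers that follow. fwd_cta_budget, main and the hard-coded 45056-byte shared-memory figure are
illustrative stand-ins; the real code queries these values through at::cuda::utils helpers.

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

static int fwd_cta_budget(int device_id, int max_cta_per_sm, int cta_launch_margin) {
  int sm_count = 0, smem_per_sm = 0, cc_major = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id);
  cudaDeviceGetAttribute(&smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id);
  cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device_id);

  const int fwd_smem_bytes = 45056;  // smem tile + reduction buffer of the fwd kernel (illustrative)
  int occupancy = std::min(max_cta_per_sm, smem_per_sm / fwd_smem_bytes);
  int budget = sm_count * occupancy;
  if (cc_major >= 7)                 // on Volta and newer, leave some CTAs free for concurrent work
    budget -= cta_launch_margin;
  return std::max(1, budget);        // at least one CTA is always launched
}

int main() {
  int device_id = 0;
  cudaGetDevice(&device_id);
  printf("fwd CTA budget: %d\n",
         fwd_cta_budget(device_id, /*max_cta_per_sm=*/2, /*cta_launch_margin=*/8));
  return 0;
}
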
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm.h"
#include <cuda.h>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
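// e.g. round_up_to_multiple(1000, 512) == 1024; used below to 512-byte-align workspace offsets.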
return ((x + multiple - 1) / multiple) * multiple;
}
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
THCudaFree(at::globalContext().lazyInitCUDA(), data);
}
}
size_t size;
void* data;
};
// Return {y}
at::Tensor nhwc_bn_fwd_train(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for synchronization
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
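// Workspace layout (see NhwcBatchNorm::setWorkspacePointers): [0] saved mean, [1] saved inv-variance,
// [2] retired-CTA counters, [3] partial sums, [4] partial counts. [0]/[1] come from caller tensors,
// [2] is allocated separately below, and [3]/[4] are packed into one 512-byte-aligned buffer.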
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
// fuse_relu selects whether ReLU is fused into the forward kernel
bn->fwd(stream, fuse_relu, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const int bn_group,
const float momentum,
const float epsilon,
const bool fuse_relu) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(nullptr);
workspace.push_back(nullptr);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// fuse_relu selects whether ReLU is fused into the inference kernel
bn->fwdInference(stream, fuse_relu);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
// shape
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for synchronization
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// outputs
at::Tensor x_grad, scale_grad, bias_grad;
// Allocate outputs
x_grad = at::empty_like(x);
scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias);
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
x_grad.data<at::Half>(),
nullptr,
dy.data<at::Half>());
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {scale_grad.data<float>(), bias_grad.data<float>()});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
bn->dgrad(stream, fuse_relu, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include <cudnn.h>
#include <algorithm>
#include <vector>
#include <string>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNorm {
public:
NhwcBatchNorm() {
name_ = "nhwc_batchnorm";
createTensorDescriptor(&X_tensor_desc_);
createTensorDescriptor(&Y_tensor_desc_);
}
~NhwcBatchNorm() {
destroyTensorDescriptor(X_tensor_desc_);
destroyTensorDescriptor(Y_tensor_desc_);
}
void die() {
std::cerr << "batchnorm not initialized" << std::endl;
exit(-1);
}
void fwd(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
void dgrad(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
void fwdInference(cudaStream_t stream, bool use_relu);
dim3 calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
dim3 calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
c_ = c;
// factor to scale sum of squared errors to get saved variance. Must be 1/(nhw*bn_group).
svar_inv_count_ = 1.f / m_bn_adjusted;
// factor to scale sum of squared errors to get running variance. Should be 1/(nhw*bn_group - 1).
int divisor = m_bn_adjusted - 1;
// nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
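// Worked example: n=64, h=w=7, bn_group=2 -> m_ = 3136, m_bn_adjusted = 6272,
// svar_inv_count_ = 1/6272, rvar_inv_count_ = 1/6271.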
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
void setOutputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
const std::vector<size_t> numWorkspaceBytes() const;
void setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes);
void setInputOutputPointers(void* X, void* dX, void* Y, void *dY) {
X_ = X;
dX_ = dX;
Y_ = Y;
dY_ = dY;
}
// Sets the pointers for the scale and weight (in that order) data and derivative buffers.
void setWeightPointers(const std::vector<void*>& weight_pointers,
const std::vector<void*>& deriv_pointers) {
assert(weight_pointers.size() == 2);
assert(deriv_pointers.size() == 2);
scale_ = static_cast<float*>(weight_pointers[0]);
bias_ = static_cast<float*>(weight_pointers[1]);
dscale_ = static_cast<float*>(deriv_pointers[0]);
dbias_ = static_cast<float*>(deriv_pointers[1]);
}
// Sets the pointers for the population mean and variance buffers, in that order.
void setParameterPointers(const std::vector<void*>& param_pointers) {
assert(param_pointers.size() == 2);
population_mean_ = static_cast<float*>(param_pointers[0]);
population_variance_ = static_cast<float*>(param_pointers[1]);
}
void setConstants(const double exp_avg_factor, const double eps) {
exp_avg_factor_ = exp_avg_factor;
eps_ = eps;
}
void processCudnnStatus(const cudnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
if (status != CUDNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
}
void checkCudaStatus(const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess)
LOG(FATAL) << string << " " << cudaGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudaGetErrorString(status);
}
size_t size_retired_ctas(int grid_y) const {
// Note that the value of max_grid_y to handle known GPUs is about 160.
const int max_grid_y = 1024;
if (grid_y > max_grid_y)
LOG(INFO) << "GPU capabilities exceeds assumptions.";
const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
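// i.e. 1024 * 2 * sizeof(int) = 8192 bytes -- generous headroom over the ~160 rows noted above.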
// Since the region will be initialized once and used for many kernels,
// the idea is to return an ample size that will cover all uses.
return retired_cta_bytes;
}
cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
void* Y_ = nullptr;
void* dY_ = nullptr;
// Learned scale and bias weights.
float* scale_ = nullptr;
float* dscale_ = nullptr;
float* bias_ = nullptr;
float* dbias_ = nullptr;
// Computed population mean and variance parameters.
float* population_mean_ = nullptr;
float* population_variance_ = nullptr;
// Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
float* minibatch_mean_ = nullptr;
float* minibatch_variance_ = nullptr;
int m_ = 0; // Number of values per channel that BN is normalizing.
int c_ = 0; // Number of channels over which BN is normalizing.
float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance
float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance
double exp_avg_factor_ = 0.;
double eps_ = 0.;
std::string name_;
private:
void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
cudnnTensorFormat_t format,
cudnnDataType_t data_type,
int n, int c, int h, int w) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnCreateTensorDescriptor(descriptor);
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnDestroyTensorDescriptor(descriptor);
processCudnnStatus(status, "destroy tensor_descriptor");
}
protected:
float *partial_sums_ = nullptr;
int *partial_counts_ = nullptr;
int *retired_ctas_ = nullptr;
void _setFwdParams(NhwcBatchNormFwdParams *params) const;
void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
void _setBwdParams(NhwcBatchNormBwdParams *params) const;
// @todo: ability to configure these?
// Kernel params
static const int USE_ONLINE_APPROACH = 1;
static const int THREADS_PER_CTA = 512;
static const int THREADS_PER_PIXEL = 16;
static const int C_ELEMENTS_PER_CTA = 64;
static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
typedef uint16_t StorageType;
//typedef float StorageType;
// increasing this to 6 causes spills in fwd kernel!
static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;
static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
PIXELS_PER_THREAD_IN_SMEM_FWD;
static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
PIXELS_PER_THREAD_IN_SMEM_BWD;
static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;
// Derived params
static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*sizeof(StorageType);
static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*2*sizeof(StorageType);
static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD;
static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_BWD;
static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD_INFERENCE;
// max grid.y in case of group bn is limited by exchange buffer size
static const int MAX_GBN_BLOCK_Y = 256;
// Helper function to launch the forward kernel.
// We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, int device_id, const int max_cta_per_sm) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
int occupancy = smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
if (outer_loops == 1 && use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, true, false, 2);
else
LAUNCH_FWD_KERNEL(1, true, false, 1);
} else if (outer_loops == 1 && !use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, false, false, 2);
else
LAUNCH_FWD_KERNEL(1, false, false, 1);
} else if (use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, true, false, 2);
else
LAUNCH_FWD_KERNEL(0, true, false, 1);
} else {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, false, false, 2);
else
LAUNCH_FWD_KERNEL(0, false, false, 1);
}
#undef LAUNCH_FWD_KERNEL
}
// Helper function to launch the backward kernel.
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, int device_id, const int max_cta_per_sm) {
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
int occupancy = smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
if (outer_loops == 1 && use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_RELU_KERNEL(1, 2);
else
LAUNCH_BWD_RELU_KERNEL(1, 1);
} else if (outer_loops == 1 && !use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_KERNEL(1, 2);
else
LAUNCH_BWD_KERNEL(1, 1);
} else if (use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_RELU_KERNEL(0, 2);
else
LAUNCH_BWD_RELU_KERNEL(0, 1);
} else {
if (occupancy >= 2)
LAUNCH_BWD_KERNEL(0, 2);
else
LAUNCH_BWD_KERNEL(0, 1);
}
#undef LAUNCH_BWD_KERNEL
}
private:
// Calculate the max number of CTAs allowed in the grid for the fwd kernel.
static size_t max_fwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
using namespace at::cuda::utils;
int answer = MultiprocessorCount(device_id) * smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
if (SMArch(device_id) >= 70)
answer -= cta_launch_margin;
answer = std::max(1, answer); // we need at least one CTA to operate
return static_cast<size_t>(answer);
}
// Calculate the max number of CTAs allowed in the grid for the bwd kernel.
static size_t max_bwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
using namespace at::cuda::utils;
int answer = MultiprocessorCount(device_id) * smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
if (SMArch(device_id) >= 70)
answer -= cta_launch_margin;
answer = std::max(1, answer); // we need at least one CTA to operate
return static_cast<size_t>(answer);
}
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
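// Illustrative numbers: SMEM_SIZE_FWD = 10*512*4*2 = 40960 B and fwd_reduction_bytes = 16*16*4*4
// = 4096 B, so fwd_smem_bytes = 45056 B; with 96 KB of shared memory per SM (e.g. Volta) the
// division below yields 2, which max_cta_per_sm may then cap further.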
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
};
const std::vector<size_t> NhwcBatchNorm::numWorkspaceBytes() const {
assert(c_ > 0);
// choose the max memory required between fwd/bwd passes
int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
int grid_x = max(grid_x_fwd, grid_x_bwd);
int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t num_mean_bytes = c_ * sizeof(float);
const size_t num_variance_bytes = num_mean_bytes;
const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\
ELEMENTS_PER_LDG*2*sizeof(float);
const size_t size_counts = grid_y*grid_x*sizeof(int);
return {num_mean_bytes, num_variance_bytes,
size_retired_ctas(grid_y), size_sums, size_counts};
}
void NhwcBatchNorm::setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes) {
assert(workspace.size() == 5);
assert(num_workspace_bytes.size() == 5);
minibatch_mean_ = static_cast<float*>(workspace[0]);
minibatch_variance_ = static_cast<float*>(workspace[1]);
retired_ctas_ = static_cast<int*>(workspace[2]);
partial_sums_ = static_cast<float*>(workspace[3]);
partial_counts_ = static_cast<int*>(workspace[4]);
}
void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = nullptr;
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_running_mean = population_mean_;
params->gmem_running_var = population_variance_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->gmem_relu_bitmask = nullptr;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->rvar_inv_count = rvar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_counts = partial_counts_;
params->gmem_retired_ctas = retired_ctas_;
params->var_eps = eps_;
params->outer_loops = 0;
params->exp_avg_factor = static_cast<float>(exp_avg_factor_);
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams
*params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = nullptr;
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_mean = population_mean_;
params->gmem_var = population_variance_;
params->nhw = m_;
params->c = c_;
params->var_eps = eps_;
}
void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dy = static_cast<uint16_t*>(dY_);
params->gmem_dst = static_cast<uint16_t*>(dX_);
params->gmem_dst1 = nullptr;
params->gmem_relu_bitmask = nullptr;
params->gmem_dscale = dscale_;
params->gmem_dbias = dbias_;
params->gmem_scale = scale_;
params->gmem_bias = bias_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_retired_ctas = retired_ctas_;
params->outer_loops = 0;
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
// && minibatch_mean_ != nullptr
// && minibatch_variance_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr;
if (!ptrs_are_set)
die();
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);
// @todo: maybe just move this inside initialize routine?
NhwcBatchNormFwdInferenceParams params;
_setFwdInferenceParams(&params);
if (use_relu) {
nhwc_batch_norm_fwd_inference
<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, true, false>
<<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
checkCudaStatus(name_ + " fwd_inference-relu kernel");
} else {
nhwc_batch_norm_fwd_inference
<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, false>
<<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
checkCudaStatus(name_ + " fwd_inference kernel");
}
}
dim3 NhwcBatchNorm::calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = max_fwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
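// Illustrative sizing with grid_dim.x = 80: the smem-resident portion covers 10*32*80 = 25600
// pixels of nhw, and each outer loop then streams 5*32*80 = 12800 register-resident pixels.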
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
dim3 NhwcBatchNorm::calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = max_bwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y<MAX_GBN_BLOCK_Y); //FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr
&& retired_ctas_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormFwdParams params;
_setFwdParams(&params);
params.my_data = my_data;
params.pair_data = pair_data;
params.pair_data2 = pair_data2;
params.magic = magic;
params.sync_iters = bn_group >> 1;
dim3 grid_dim = calc_fwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, device_id, max_cta_per_sm);
}
void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& (bias_ != nullptr || !use_relu)
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
// && population_mean_ != nullptr
// && population_variance_ != nullptr
&& X_ != nullptr
&& dX_ != nullptr
// && Y_ != nullptr
&& dY_ != nullptr
&& dscale_ != nullptr
&& dbias_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormBwdParams params;
_setBwdParams(&params);
params.my_data = my_data;
params.pair_data = pair_data;
params.pair_data2 = pair_data2;
params.magic = magic;
params.sync_iters = bn_group >> 1;
params.wgrad_coeff = 1.0 / bn_group;
dim3 grid_dim = calc_bwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, device_id, max_cta_per_sm);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm_add_relu.h"
#include <cuda.h>
//FIXME move the common stuff to common h file
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
return ((x + multiple - 1) / multiple) * multiple;
}
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
THCudaFree(at::globalContext().lazyInitCUDA(), data);
}
}
size_t size;
void* data;
};
// Return {y}
at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for synchronization
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr,
z.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 3 workspace ptrs (mean, inv-variance, bitmask), then
// allocate & offset an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
workspace.push_back(bitmask.data<int32_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
// add + ReLU are always fused in this variant
bn->fwd(stream, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const int bn_group,
const float momentum,
const float epsilon) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr,
z.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 3 workspace ptrs (mean, inv-variance, bitmask), then
// allocate & offset an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(nullptr);
workspace.push_back(nullptr);
workspace.push_back(nullptr);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// add + ReLU are always fused in this variant
bn->fwdInference(stream);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
// shape
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for synchronization
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// outputs
at::Tensor x_grad, z_grad, scale_grad, bias_grad;
// Allocate outputs
x_grad = at::empty_like(x);
z_grad = at::empty_like(x);
scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias);
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
x_grad.data<at::Half>(),
nullptr,
dy.data<at::Half>(),
nullptr,
z_grad.data<at::Half>());
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {scale_grad.data<float>(), bias_grad.data<float>()});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 3 workspace ptrs (mean, inv-variance, bitmask), then
// allocate & offset an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
workspace.push_back(bitmask.data<int32_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
bn->dgrad(stream, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_add_relu.h
* \brief CUDA NHWC Batch Normalization code with fused addition
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include <cudnn.h>
#include <algorithm>
#include <vector>
#include <string>
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#define VERBOSE_DEFAULT false
class NhwcBatchNormAddRelu {
public:
NhwcBatchNormAddRelu() {
name_ = "nhwc_batchnormaddrelu";
createTensorDescriptor(&X_tensor_desc_);
createTensorDescriptor(&Y_tensor_desc_);
}
~NhwcBatchNormAddRelu() {
destroyTensorDescriptor(X_tensor_desc_);
destroyTensorDescriptor(Y_tensor_desc_);
}
void die() {
std::cerr << "batchnormaddrelu not initialized" << std::endl;
exit(-1);
}
void fwd(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
void dgrad(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
void fwdInference(cudaStream_t stream);
dim3 calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
dim3 calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
int n, int c, int h, int w, int bn_group) {
m_ = n * h * w;
int m_bn_adjusted = m_ * bn_group;
c_ = c;
// factor to scale sum of squared errors to get saved variance. Must be 1/(nhw*bn_group).
svar_inv_count_ = 1.f / m_bn_adjusted;
// factor to scale sum of squared errors to get running variance. Should be 1/(nhw*bn_group - 1).
int divisor = m_bn_adjusted - 1;
// nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs.
rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor;
setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w);
}
void setOutputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
int n, int c, int h, int w) {
setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w);
}
const std::vector<size_t> numWorkspaceBytes() const;
void setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes);
void setInputOutputPointers(void* X, void* dX, void* Y, void *dY, void* addend, void* dAddend) {
X_ = X;
dX_ = dX;
Y_ = Y;
dY_ = dY;
addend_ = addend;
dAddend_ = dAddend;
}
// Sets the pointers for the scale and weight (in that order) data and derivative buffers.
void setWeightPointers(const std::vector<void*>& weight_pointers,
const std::vector<void*>& deriv_pointers) {
assert(weight_pointers.size() == 2);
assert(deriv_pointers.size() == 2);
scale_ = static_cast<float*>(weight_pointers[0]);
bias_ = static_cast<float*>(weight_pointers[1]);
dscale_ = static_cast<float*>(deriv_pointers[0]);
dbias_ = static_cast<float*>(deriv_pointers[1]);
}
// Sets the pointers for the population mean and variance buffers, in that order.
void setParameterPointers(const std::vector<void*>& param_pointers) {
assert(param_pointers.size() == 2);
population_mean_ = static_cast<float*>(param_pointers[0]);
population_variance_ = static_cast<float*>(param_pointers[1]);
}
void setConstants(const double exp_avg_factor, const double eps) {
exp_avg_factor_ = exp_avg_factor;
eps_ = eps;
}
void processCudnnStatus(const cudnnStatus_t& status,
const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
if (status != CUDNN_STATUS_SUCCESS)
LOG(FATAL) << string << " " << cudnnGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudnnGetErrorString(status);
}
void checkCudaStatus(const std::string& string = std::string(),
bool verbose = VERBOSE_DEFAULT) {
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess)
LOG(FATAL) << string << " " << cudaGetErrorString(status);
else if (verbose)
LOG(INFO) << string << " " << cudaGetErrorString(status);
}
size_t size_retired_ctas(int grid_y) const {
// Note that the value of max_grid_y to handle known GPUs is about 160.
const int max_grid_y = 1024;
if (grid_y > max_grid_y)
LOG(INFO) << "GPU capabilities exceeds assumptions.";
const int retired_cta_bytes = max_grid_y * 2 * sizeof(int);
// Since the region will be initialized once and used for many kernels,
// the idea is to return an ample size that will cover all uses.
return retired_cta_bytes;
}
cudnnTensorDescriptor_t X_tensor_desc_ = nullptr;
cudnnTensorDescriptor_t Y_tensor_desc_ = nullptr;
void* X_ = nullptr;
void* dX_ = nullptr;
void* Y_ = nullptr;
void* dY_ = nullptr;
void* addend_ = nullptr;
void* dAddend_ = nullptr;
// Learned scale and bias weights.
float* scale_ = nullptr;
float* dscale_ = nullptr;
float* bias_ = nullptr;
float* dbias_ = nullptr;
// Computed population mean and variance parameters.
float* population_mean_ = nullptr;
float* population_variance_ = nullptr;
// Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd).
float* minibatch_mean_ = nullptr;
float* minibatch_variance_ = nullptr;
int m_ = 0; // Number of values per channel that BN is normalizing.
int c_ = 0; // Number of channels over which BN is normalizing.
float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance
float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance
double exp_avg_factor_ = 0.;
double eps_ = 0.;
std::string name_;
private:
void setTensorDescriptor(cudnnTensorDescriptor_t descriptor,
cudnnTensorFormat_t format,
cudnnDataType_t data_type,
int n, int c, int h, int w) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w);
processCudnnStatus(status, "set tensor descriptor");
}
void createTensorDescriptor(cudnnTensorDescriptor_t *descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnCreateTensorDescriptor(descriptor);
processCudnnStatus(status, "create tensor_descriptor");
}
void destroyTensorDescriptor(cudnnTensorDescriptor_t descriptor) {
cudnnStatus_t status = CUDNN_STATUS_SUCCESS;
status = cudnnDestroyTensorDescriptor(descriptor);
processCudnnStatus(status, "destroy tensor_descriptor");
}
protected:
float *partial_sums_ = nullptr;
int *partial_counts_ = nullptr;
int *retired_ctas_ = nullptr;
unsigned int *relu_bitmask_ = nullptr;
void _setFwdParams(NhwcBatchNormFwdParams *params) const;
void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const;
void _setBwdParams(NhwcBatchNormBwdParams *params) const;
// @todo: ability to configure these?
// Kernel params
static const int USE_ONLINE_APPROACH = 1;
static const int THREADS_PER_CTA = 512;
static const int THREADS_PER_PIXEL = 16;
static const int C_ELEMENTS_PER_CTA = 64;
static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
typedef uint16_t StorageType;
// increasing this to 6 causes spills in fwd kernel!
static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;
static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
PIXELS_PER_THREAD_IN_SMEM_FWD;
static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \
PIXELS_PER_THREAD_IN_SMEM_BWD;
static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4;
// Derived params
static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*sizeof(StorageType);
static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\
ELEMENTS_PER_LDG*2*sizeof(StorageType);
static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD;
static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_BWD;
static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \
PIXELS_PER_THREAD_FWD_INFERENCE;
// max grid.y in case of group bn is limited by exchange buffer size
static const int MAX_GBN_BLOCK_Y = 256;
// Helper function to launch the forward kernel.
// We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, int device_id, const int max_cta_per_sm) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_FWD, \
PIXELS_PER_THREAD_IN_SMEM_FWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
int occupancy = smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
if (outer_loops == 1) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, false, true, 2);
else
LAUNCH_FWD_KERNEL(1, false, true, 1);
} else {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, false, true, 2);
else
LAUNCH_FWD_KERNEL(0, false, true, 1);
}
#undef LAUNCH_FWD_KERNEL
}
// Helper function to launch the backward kernel.
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, int device_id, const int max_cta_per_sm) {
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>; \
if (COMPILED_FOR_OCCUPANCY > 1) { \
cudaFuncSetAttribute(bwd_add_relu_func, \
cudaFuncAttributePreferredSharedMemoryCarveout, 100); \
checkCudaStatus(name_ + \
" bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \
} \
void *params_ptr = static_cast<void*>(&params); \
using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \
StorageType, \
THREADS_PER_CTA, \
THREADS_PER_PIXEL, \
PIXELS_PER_THREAD_IN_REGISTERS_BWD, \
PIXELS_PER_THREAD_IN_SMEM_BWD, \
ELEMENTS_PER_LDG, \
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
int occupancy = smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
if (outer_loops == 1) {
if (occupancy >= 2)
LAUNCH_BWD_ADD_RELU_KERNEL(1, 2);
else
LAUNCH_BWD_ADD_RELU_KERNEL(1, 1);
} else {
if (occupancy >= 2)
LAUNCH_BWD_ADD_RELU_KERNEL(0, 2);
else
LAUNCH_BWD_ADD_RELU_KERNEL(0, 1);
}
#undef LAUNCH_BWD_ADD_RELU_KERNEL
}
private:
// Calculate the max number of CTAs allowed in the grid for the fwd kernel.
static size_t max_fwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
using namespace at::cuda::utils;
int answer = MultiprocessorCount(device_id) * smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
if (SMArch(device_id) >= 70)
answer -= cta_launch_margin;
answer = std::max(1, answer); // we need at least one CTA to operate
return static_cast<size_t>(answer);
}
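// Worked example (hypothetical numbers, for illustration only): on an 80-SM device where
// smem_driven_fwd_occupancy() returns 2 and cta_launch_margin is 3, the cooperative grid is
// capped at 80 * 2 - 3 = 157 CTAs; the subtracted margin presumably leaves a few SM slots of
// headroom for other work running concurrently with this persistent kernel.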
// Calculate the max number of CTAs allowed in the grid for the bwd kernel.
static size_t max_bwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
using namespace at::cuda::utils;
int answer = MultiprocessorCount(device_id) * smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
if (SMArch(device_id) >= 70)
answer -= cta_launch_margin;
answer = std::max(1, answer); // we need at least one CTA to operate
return static_cast<size_t>(answer);
}
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
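// Worked example (illustrative, assumed sizes): if SMEM_SIZE_FWD plus the reduction buffer
// comes to roughly 48 KiB per CTA and MaxSharedMemoryPerMultiprocessor() reports 96 KiB, the
// smem-driven occupancy is 96 / 48 = 2; max_cta_per_sm can only lower that bound, never raise
// it, which is why the launchers above never compile for an occupancy greater than 2.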
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
}
};
const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
assert(c_ > 0);
// choose the max memory required between fwd/bwd passes
int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD);
int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD);
int grid_x = max(grid_x_fwd, grid_x_bwd);
int grid_y = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t num_mean_bytes = c_ * sizeof(float);
const size_t num_variance_bytes = num_mean_bytes;
int elems_per_group = ((m_ + 31) & ~31) * 2;
int group_count = div_up(c_, C_ELEMENTS_PER_CTA);
const size_t bitmask_bytes = elems_per_group * group_count * sizeof(unsigned int);
const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\
ELEMENTS_PER_LDG*2*sizeof(float);
const size_t size_counts = grid_y*grid_x*sizeof(int);
return {num_mean_bytes, num_variance_bytes, bitmask_bytes,
size_retired_ctas(grid_y), size_sums, size_counts};
}
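// Note: the order of the six sizes returned above must line up with the indices used in
// setWorkspacePointers() below: [0] minibatch mean, [1] minibatch variance, [2] relu bitmask,
// [3] retired-CTA counters, [4] partial sums, [5] partial counts.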
void NhwcBatchNormAddRelu::setWorkspacePointers(
const std::vector<void*>& workspace,
const std::vector<size_t>& num_workspace_bytes) {
assert(workspace.size() == 6);
assert(num_workspace_bytes.size() == 6);
minibatch_mean_ = static_cast<float*>(workspace[0]);
minibatch_variance_ = static_cast<float*>(workspace[1]);
relu_bitmask_ = static_cast<unsigned int*>(workspace[2]);
retired_ctas_ = static_cast<int*>(workspace[3]);
partial_sums_ = static_cast<float*>(workspace[4]);
partial_counts_ = static_cast<int*>(workspace[5]);
}
void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = static_cast<uint16_t*>(addend_);
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_running_mean = population_mean_;
params->gmem_running_var = population_variance_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->gmem_relu_bitmask = relu_bitmask_;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->rvar_inv_count = rvar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_counts = partial_counts_;
params->gmem_retired_ctas = retired_ctas_;
params->var_eps = eps_;
params->outer_loops = 0;
params->exp_avg_factor = static_cast<float>(exp_avg_factor_);
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams
*params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dst = static_cast<uint16_t*>(Y_);
params->gmem_src1 = static_cast<uint16_t*>(addend_);
params->gmem_bias = bias_;
params->gmem_scale = scale_;
params->gmem_mean = population_mean_;
params->gmem_var = population_variance_;
params->nhw = m_;
params->c = c_;
params->var_eps = eps_;
}
void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams *params) const {
params->gmem_src = static_cast<uint16_t*>(X_);
params->gmem_dy = static_cast<uint16_t*>(dY_);
params->gmem_dst = static_cast<uint16_t*>(dX_);
params->gmem_dst1 = static_cast<uint16_t*>(dAddend_);
params->gmem_relu_bitmask = relu_bitmask_;
params->gmem_dscale = dscale_;
params->gmem_dbias = dbias_;
params->gmem_scale = scale_;
params->gmem_bias = bias_;
params->gmem_saved_mean = minibatch_mean_;
params->gmem_saved_var = minibatch_variance_;
params->nhw = m_;
params->c = c_;
params->svar_inv_count = svar_inv_count_;
params->gmem_sums = partial_sums_;
params->gmem_retired_ctas = retired_ctas_;
params->outer_loops = 0;
params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
}
void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
// && minibatch_mean_ != nullptr
// && minibatch_variance_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
&& addend_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr;
if (!ptrs_are_set)
die();
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE);
grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA);
// @todo: maybe just move this inside initialize routine?
NhwcBatchNormFwdInferenceParams params;
_setFwdInferenceParams(&params);
nhwc_batch_norm_fwd_inference
<StorageType, THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG, false, true>
<<<grid_dim, THREADS_PER_CTA, 0, stream>>>(params);
checkCudaStatus(name_ + " fwd_inference-relu kernel");
}
dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = max_fwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_FWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_FWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
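// Illustrative walk-through (hypothetical numbers): with m_ = 200704 pixels and
// PIXELS_PER_CTA_FWD = 1536, grid.x = div_up(200704, 1536) = 131. If max_fwd_grid_x() returns
// 157, everything fits in one pass (*loop = 1) and grid.y stays 1 because 157 / 131 == 1; a
// much smaller m_ would instead let grid.y grow (up to c_blks, bounded by MAX_GBN_BLOCK_Y) so
// that spare CTAs cover additional channel blocks.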
dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = max_bwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
grid_dim.y = std::min(c_blks, static_cast<int>(max_grid_x / grid_dim.x));
assert(grid_dim.y < MAX_GBN_BLOCK_Y);  // FIXME: turn into a loop
} else {
grid_dim.y = 1;
}
} else {
grid_dim.x = max_grid_x;
grid_dim.y = 1;
int nhw_in_regs = m_ - PIXELS_PER_THREAD_IN_SMEM_BWD*PIXELS_PER_LDG*grid_dim.x;
int pixels_per_iteration = PIXELS_PER_THREAD_IN_REGISTERS_BWD*PIXELS_PER_LDG*grid_dim.x;
*loop = div_up(nhw_in_regs, pixels_per_iteration);
}
return grid_dim;
}
void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
&& relu_bitmask_ != nullptr
&& population_mean_ != nullptr
&& population_variance_ != nullptr
&& X_ != nullptr
// && dX_ != nullptr
&& Y_ != nullptr
&& addend_ != nullptr
// && dY_ != nullptr
// && dscale_ != nullptr
// && dbias_ != nullptr
&& partial_sums_ != nullptr
&& partial_counts_ != nullptr
&& retired_ctas_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormFwdParams params;
_setFwdParams(&params);
params.my_data = my_data;
params.pair_data = pair_data;
params.pair_data2 = pair_data2;
params.magic = magic;
params.sync_iters = bn_group >> 1;
dim3 grid_dim = calc_fwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, device_id, max_cta_per_sm);
}
void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
&& scale_ != nullptr
&& bias_ != nullptr
&& minibatch_mean_ != nullptr
&& minibatch_variance_ != nullptr
&& relu_bitmask_ != nullptr
// && population_mean_ != nullptr
// && population_variance_ != nullptr
&& X_ != nullptr
&& dX_ != nullptr
// && Y_ != nullptr
&& dY_ != nullptr
&& dAddend_ != nullptr
&& dscale_ != nullptr
&& dbias_ != nullptr
&& retired_ctas_ != nullptr;
if (!ptrs_are_set)
die();
// reset of retired_cta_count no longer needed
NhwcBatchNormBwdParams params;
_setBwdParams(&params);
params.my_data = my_data;
params.pair_data = pair_data;
params.pair_data2 = pair_data2;
params.magic = magic;
params.sync_iters = bn_group >> 1;
params.wgrad_coeff = 1.0 / bn_group;
dim3 grid_dim = calc_bwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, device_id, max_cta_per_sm);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
#include <ATen/cuda/CUDAContext.h>
#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H
namespace at {
namespace cuda {
namespace utils {
// Eventually these should be replaced by real device-query functions.
static inline int MultiprocessorCount(int device_id) {
return getDeviceProperties(device_id)->multiProcessorCount;
}
static inline int SMArch(int device_id) {
auto device_property = getDeviceProperties(device_id);
int cc = device_property->major * 10 + device_property->minor;
return cc;
}
static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
}
}
}
}
#endif
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ArrayRef.h>
#include <ATen/ScalarType.h>
#include "ATen/Scalar.h"
#include "ATen/Type.h"
#include "ATen/Tensor.h"
#include "ATen/Storage.h"
#include "ATen/Generator.h"
namespace py = pybind11;
int64_t get_buffer_size(
const int bn_sync_steps);
void* get_data_ptr(
const at::Tensor& data);
void* get_remote_data_ptr(
const at::Tensor& handle,
const int64_t offset);
void close_remote_data(
const at::Tensor& handle);
at::Tensor nhwc_bn_fwd_train(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const int bn_group,
const float momentum,
const float epsilon,
const bool fuse_relu);
std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const int bn_group,
const float momentum,
const float epsilon);
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_buffer_size", &get_buffer_size, "get_buffer_size");
m.def("get_data_ptr", &get_data_ptr, "get_data_ptr");
m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr");
m.def("close_remote_data", &close_remote_data, "close_remote_data");
m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc");
m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc");
m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc");
m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc");
m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc");
m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include <cuda.h>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
template<>
struct std::hash<cudaIpcMemHandle_t> {
size_t operator() (const cudaIpcMemHandle_t& handle) const {
size_t hash = 0;
uint8_t* ptr = (uint8_t*)&handle;
assert(sizeof(uint8_t) == 1);
for (int i=0; i<sizeof(cudaIpcMemHandle_t); i++) {
hash += *ptr;
ptr++;
}
return hash;
}
};
template<>
struct std::equal_to<cudaIpcMemHandle_t> {
bool operator() (const cudaIpcMemHandle_t &lhs,
const cudaIpcMemHandle_t &rhs) const {
return (std::memcmp((void*) &lhs,
(void*) &rhs,
sizeof(cudaIpcMemHandle_t)) == 0);
}
};
namespace {
namespace gpuipc {
//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The number of reducing ops, each uses its own space : mean, var, dscale, dbias
const int REDUCE_OPS = 4;
// Maximum block.y supported - limited due to buffer allocation
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
const int BYTES_PER_ELEM = 4;
// Buffer size per sync step
const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*(1+ELEMENTS_PER_LDG)*BYTES_PER_ELEM;
};
class IpcMemHandleRegistry {
public:
void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) {
if (registry_.count(handle) == 0) {
registry_.insert(std::make_pair(handle, RegistryEntry()));
registry_[handle].dev_ptr = ipcOpenMem(handle);
}
registry_[handle].ref_count++;
return (((uint8_t*)registry_[handle].dev_ptr) + offset);
}
void releasePtr(const cudaIpcMemHandle_t& handle) {
if (registry_.count(handle) == 0) {
return;  // unknown handle: nothing to release
}
if (--registry_[handle].ref_count == 0) {
ipcCloseMem(registry_[handle].dev_ptr);
registry_.erase(handle);
}
}
struct RegistryEntry {
void* dev_ptr;
int ref_count;
RegistryEntry() : dev_ptr(NULL) , ref_count(0) {}
};
protected:
std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;
void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {
void *data;
cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);
cudaCheckErrors("ipc init");
return data;
}
void ipcCloseMem(void* dev_ptr) {
cudaIpcCloseMemHandle(dev_ptr);
cudaCheckErrors("ipc close");
}
};
}
static IpcMemHandleRegistry ipc_mem_registry;
int64_t get_buffer_size(const int bn_sync_steps) {
return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES;
}
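// Sizing example, derived from the gpuipc constants above: SINGLE_SYNC_BUFFER_BYTES =
// MAX_OFFSET(1024) * THREADS_PER_PIXEL(16) * (1 flag + ELEMENTS_PER_LDG(4) data) * 4 bytes
// = 327,680 bytes per sync step. The kernels use sync_iters = bn_group >> 1, so for a
// bn_group of 4 a matching call get_buffer_size(2) would return 655,360 bytes.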
void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {
cudaIpcMemHandle_t my_handle;
memcpy((unsigned char *)(&my_handle), handle.data<uint8_t>(), sizeof(my_handle));
return ipc_mem_registry.getPtr(my_handle, offset);
}
void close_remote_data(const at::Tensor& handle) {
cudaIpcMemHandle_t my_handle;
memcpy((unsigned char *)(&my_handle), handle.data<uint8_t>(), sizeof(my_handle));
ipc_mem_registry.releasePtr(my_handle);
}
void* get_data_ptr(
const at::Tensor& data) {
return data.data<uint8_t>();
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nhwc_batch_norm_kernel.h
* \brief CUDA NHWC Batch Normalization code
* \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer
*/
#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
#include <stdint.h>
#include <algorithm>
#define DEVICE_FUNCTION static inline __device__
// CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN.
#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3
#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN
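// A minimal sketch of how the margin override could be read (illustrative only; the actual
// parsing happens where the layer is created, not in this header, and may differ):
//   int margin = NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT;
//   if (const char* env = std::getenv("NHWC_BATCHNORM_LAUNCH_MARGIN"))  // needs <cstdlib>
//     margin = std::max(std::atoi(env), NHWC_BATCHNORM_LAUNCH_MARGIN_MIN);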
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T, int ELEMENTS_PER_LDG >
struct PackedStorage {
enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG };
typedef T Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int ELEMENTS_PER_LDG >
struct PackedStorage<uint16_t, ELEMENTS_PER_LDG> {
enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG/2 };
typedef int Type;
};
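// Example: for fp16 input (Storage = uint16_t) with ELEMENTS_PER_LDG = 4, a thread's four
// half values travel as PACKED_ELEMENTS_PER_LDG = 2 ints, each int holding a packed pair of
// halves; the from_float()/to_float() helpers below convert each packed pair with
// cvt.rn.f16.f32 / cvt.f32.f16. For float input the generic template keeps the values unpacked.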
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
uint16_t lo, hi;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2*i+0]));
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2*i+1]));
asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i]));
asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+0]) : "h"(lo));
asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+1]) : "h"(hi));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) {
dst[0] = __ldg((const int*) gmem);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) {
unsigned int tmp;
asm volatile ("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l" ((const uint *)gmem));
dst[0] = tmp;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) {
int2 tmp = __ldg((const int2*) gmem);
dst[0] = tmp.x;
dst[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) {
int2 tmp;
asm volatile ("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];"
: "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2 *)gmem));
dst[0] = tmp.x;
dst[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t *gmem) {
int tmp[N/2];
ldg(tmp, gmem);
to_float(dst, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void ldg_stream(float (&dst)[N], const uint16_t *gmem) {
int tmp[N/2];
ldg_stream(tmp, gmem);
to_float(dst, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) {
reinterpret_cast<int*>(gmem)[0] = src[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) {
unsigned int tmp = src[0];
asm volatile ("st.global.cs.s32 [%0], %1;"
:: "l"((uint *)gmem) , "r"(tmp));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) {
reinterpret_cast<int2*>(gmem)[0] = make_int2(src[0], src[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) {
asm volatile ("st.global.cs.v2.s32 [%0], {%1,%2};"
:: "l"((uint *)gmem) , "r"(src[0]), "r"( src[1]));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[N]) {
int tmp[N/2];
from_float(tmp, src);
stg(gmem, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) {
int tmp[N/2];
from_float(tmp, src);
stg_stream(gmem, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) {
float2 tmp = __ldg(reinterpret_cast<const float2*>(&gmem[2*idx]));
dst[0] = tmp.x;
dst[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) {
float4 tmp = __ldg(reinterpret_cast<const float4*>(&gmem[4*idx]));
dst[0] = tmp.x;
dst[1] = tmp.y;
dst[2] = tmp.z;
dst[3] = tmp.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) {
float2 tmp = *(const float2*) &smem[2*idx];
x[0] = tmp.x;
x[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) {
x[0] = smem[idx];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) {
float4 tmp = *(const float4*) &smem[4*idx];
x[0] = tmp.x;
x[1] = tmp.y;
x[2] = tmp.z;
x[3] = tmp.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) {
int2 tmp = *(const int2*) &smem[2*idx];
x[0] = tmp.x;
x[1] = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) {
reinterpret_cast<float2*>(&gmem[2*idx])[0] = make_float2(src[0], src[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) {
reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], src[3]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) {
reinterpret_cast<float4*>(&gmem[4*idx])[0] = make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) {
reinterpret_cast<float2*>(&smem[2*idx])[0] = make_float2(x[0], x[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) {
smem[idx] = x[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) {
reinterpret_cast<float4*>(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) {
reinterpret_cast<int2*>(&smem[2*idx])[0] = make_int2(x[0], x[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void zero_array(int (&dst)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = 0;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
DEVICE_FUNCTION void zero_array(float (&dst)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = 0.f;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] += y[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] *= y[i];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void scale_(float (&x)[N], float scalar) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] *= scalar;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N],
const float (&scale)[N], const float (&m1)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] = bias[i] + scale[i] * (x[i] - m1[i]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Storage>
DEVICE_FUNCTION Storage relu(Storage in) {
Storage zero = (Storage)0.f;
return (in < zero)? zero : in;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_activation(float (&x)[N]) {
#pragma unroll
for (int i = 0; i < N; ++i) {
x[i] = relu(x[i]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, void* params_my_data, void* params_pair_data, int off, const int magic, void* params_pair_data2, const unsigned int& sync_iters) {
// The size of a warp.
const int THREADS_PER_WARP = 32;
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The number of reducing ops, each uses its own space : mean, var, dscale, dbias
const int REDUCE_OPS = 4;
// Maximum block.y supported - limited due to buffer allocation
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
// The warp decomposition.
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int lane_id = threadIdx.x % THREADS_PER_WARP;
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("start parallel_sums_16x2 off=%d magic=%d sync_iters=%d thread%d block %d , %d\n", off, magic, sync_iters, threadIdx.x, blockIdx.x, blockIdx.y);
#endif
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
}
// The warp leaders write to SMEM.
if (lane_id < THREADS_PER_PIXEL) {
write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
}
// The data is in SMEM. Do the final reduction.
__syncthreads();
// The 1st warp does all the work.
// Do the final reduction: each half-warp sequentially reduces the final values.
if (warp_id == 0) {
read_from_smem(x, smem, threadIdx.x);
#pragma unroll
for (int offset = 1;
offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
// Compute the updated sum.
add(x, y);
}
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
}
// Make sure the data was read from SMEM.
__syncwarp();
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
//probably could do it earlier, before sync
for (int sync_iter=0; sync_iter<sync_iters; ++sync_iter)
{
// total size of the flags per sync iter; skipped over to reach the data
const int flags_total = MAX_OFFSET*THREADS_PER_PIXEL;
// total size of data per sync iter
const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
//skip the space consumed by previous sync iterations
const int xbuf_offset = sync_iter*(flags_total+data_total);
// flags are at the beginning of the buffer, one per thread
const int flags_offset = xbuf_offset + off*THREADS_PER_PIXEL;
// data starts after the flags; also skip the slots used by previous offsets
const int data_offset = xbuf_offset + flags_total + off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL + ELEMENTS_PER_LDG*threadIdx.x;
// after the sums for this GPU are computed, let CTA0 broadcast the sum to the other GPU
if (blockIdx.x==0)
{
volatile float * write_data = &(((float*)params_pair_data)[data_offset]);
volatile int32_t * write_flag = &(((int32_t*)((params_pair_data)))[flags_offset]);
// write the data to the memory region that is reflected to the other GPU
asm volatile ("st.global.wt.v4.f32 [%0], {%1,%2,%3,%4};"
:: "l"((float4 *)write_data) , "f"(x[0]), "f"( x[1]), "f"(x[2]), "f"( x[3]));
__threadfence_system();
//write the magic value to indicate data readiness
write_flag[threadIdx.x] = magic; //or can sync and set only one flag
#ifdef BNDEBUG
printf("writing buddy flag, thread %d myvalue %d data offset %d flag offset %d\n", threadIdx.x, magic, 4*THREADS_PER_PIXEL+off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL + ELEMENTS_PER_LDG*threadIdx.x, off*THREADS_PER_PIXEL);
#endif
}
//now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
volatile float * read_data_ = &(((float*)params_my_data)[data_offset]);
volatile int32_t * read_flag = &(((int32_t*)((params_my_data)))[flags_offset]);
//check if other side has written
#ifdef BNDEBUG
unsigned int safety=0;
while ((read_flag[threadIdx.x] % 1000000) != (magic % 1000000) )
{
++safety;
if (safety>99999) {
printf("stuck waiting for my buddy, thread %d myvalue %d data offset %d flag offset %d read value %d\n", threadIdx.x, magic, 4*THREADS_PER_PIXEL+off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL + ELEMENTS_PER_LDG*threadIdx.x, off*THREADS_PER_PIXEL, read_flag[threadIdx.x]);
safety=0;
}
}
#else
while ((read_flag[threadIdx.x] ) != (magic ) ) ;
#endif
float other[4];
asm volatile ("ld.global.cv.v4.f32 {%0, %1, %2, %3}, [%4];"
: "=f"(other[0]), "=f"(other[1]), "=f"(other[2]), "=f"(other[3]) : "l"(read_data_));
add(x, other);
params_pair_data = params_pair_data2; //FIXME use an array
}
// finally, after syncing up and accounting for partial sums from other GPUs as required, write the result
write_to_smem(smem, threadIdx.x, x);
}
}
}
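// Exchange-buffer layout used above (per sync_iter): MAX_OFFSET*THREADS_PER_PIXEL int32 flags
// come first, followed by MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG floats of data; 'off'
// selects one of up to MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y slots inside both regions (the
// callers fold the reduction op and the c-block index into it), so the different reductions
// never share flag or data space within a sync step.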
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
// The size of a warp.
const int THREADS_PER_WARP = 32;
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 8;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The warp decomposition.
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int lane_id = threadIdx.x % THREADS_PER_WARP;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);
}
// The warp leaders write to SMEM.
if (lane_id < THREADS_PER_PIXEL) {
write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
}
// The data is in SMEM. Do the final reduction.
__syncthreads();
// The 1st warp does all the work.
// Do the final reduction: each half-warp sequentially reduces the final values.
if (warp_id == 0) {
read_from_smem(x, smem, threadIdx.x);
#pragma unroll
for (int offset = 1;
offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
// Compute the updated sum.
add(x, y);
}
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL+lane_id);
x[i] += __shfl_sync(0xffffffffU, x[i], THREADS_PER_PIXEL*2+lane_id);
}
// Make sure the data was read from SMEM.
__syncwarp();
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
write_to_smem(smem, threadIdx.x, x);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
// The size of a warp.
const int THREADS_PER_WARP = 32;
// The number of warps in a CTA.
const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
// The number of pixels computed by a single warp.
const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL;
// The position in the warp.
const int nhw_in_warp = nhw % PIXELS_PER_WARP;
// The C in the warp.
const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL;
// Store the values to shared memory.
write_to_smem(smem, threadIdx.x, x);
// Compute the parallel sums.
for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) {
// NOP.
__syncwarp();
// Read the running sum from the other thread.
float y[ELEMENTS_PER_LDG];
if (nhw_in_warp < offset) {
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
}
// Compute the updated sum.
add(x, y);
// NOP.
__syncwarp();
// Update the sum in SMEM.
if (offset > 1 && nhw_in_warp < offset) {
write_to_smem(smem, threadIdx.x, x);
}
}
// The warps are done. Do the final reduction at the CTA level.
__syncthreads();
// The warp leaders write to SMEM.
const int idx = (threadIdx.x/THREADS_PER_WARP)*THREADS_PER_PIXEL + c_in_warp;
if (nhw_in_warp == 0) {
write_to_smem(smem, idx, x);
}
// The data is in SMEM. Do the final reduction.
__syncthreads();
// Read the 1st element to prepare the work.
if (nhw < WARPS_PER_CTA/2) {
read_from_smem(x, smem, threadIdx.x);
}
// We have the running mean and running m2. Let's build the mean/var of the CTA.
for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) {
// NOP.
__syncwarp();
// Read the mean and variance from the other pixel.
float y[ELEMENTS_PER_LDG];
if (nhw < offset) {
read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
}
// Compute the updated sum.
add(x, y);
// NOP.
__syncwarp();
// Store the mean/var for the different pixels.
if (nhw < offset) {
write_to_smem(smem, threadIdx.x, x);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
struct ParallelSums {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
parallel_sums<THREADS_PER_CTA, THREADS_PER_PIXEL, ELEMENTS_PER_LDG>(smem, x, nhw);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct ParallelSums<16, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0, 0);
}
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void* params_pair_data, int off, const int magic, void* params_pair_data2, const unsigned int& sync_iters) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_data, off, magic, params_pair_data2, sync_iters);
}
};
template<>
struct ParallelSums<8, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
#ifdef BNDEBUGX
assert(0);
#endif
parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline int div_up(int m, int n) {
return (m + n - 1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// It is expected that all threads in the CTA enter this function!
DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) {
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("start inter_block_sync thread%d block %d , %d grid.X %d\n", threadIdx.x, blockIdx.x, blockIdx.y, gridDim.x);
#endif
// Register the CTA.
if (threadIdx.x == 0) {
// Issue the membar.
__threadfence();
// Notify that the CTA is done.
int val_to_add = 1;
if (master) {
val_to_add = -(expected_count - 1);
}
atomicAdd(gmem_retired_ctas, val_to_add);
}
// Are all CTAs done?
if (threadIdx.x == 0) {
int retired_ctas = -1;
do {
__threadfence();
asm volatile ("ld.global.cg.b32 %0, [%1];"
: "=r"(retired_ctas) : "l"(gmem_retired_ctas));
} while (retired_ctas != 0);
}
__syncthreads();
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("finish inter_block_sync thread%d block %d , %d\n", threadIdx.x, blockIdx.x, blockIdx.y);
#endif
}
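// Worked example (hypothetical grid): with gridDim.x = 157 CTAs per c-block, the master CTA
// (blockIdx.x == 0) adds -(157 - 1) = -156 while each of the other 156 CTAs adds +1. Assuming
// the counter starts at zero, it returns to exactly 0 only once every CTA has arrived; the
// spin loop above then releases all of them, and the counter is already back at zero for its
// next use, which is why no explicit reset is needed between calls.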
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormFwdInferenceParams {
// The input/output tensors.
uint16_t *gmem_src, *gmem_dst, *gmem_src1;
// the final mean and variance as calculated during the training process
float *gmem_mean, *gmem_var;
// The bias/scale.
float *gmem_bias, *gmem_scale;
// The dimensions.
int nhw, c;
// epsilon
float var_eps;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int ELEMENTS_PER_LDG,
bool USE_RELU,
bool USE_ADD_RELU
>
__global__ __launch_bounds__(THREADS_PER_CTA)
void nhwc_batch_norm_fwd_inference(NhwcBatchNormFwdInferenceParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// The start position in the NHW dimension where the CTA starts.
const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// thread's starting point in NHW
const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG;
// The position in the C dimension where the CTA starts.
const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG];
float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG];
zero_array(mean);
zero_array(var);
zero_array(scale);
zero_array(bias);
if (is_valid_c) {
read_from_gmem(var, &params.gmem_var[cta_c], thread_in_cta_c);
read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);
read_from_gmem(mean, &params.gmem_mean[cta_c], thread_in_cta_c);
read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);
}
// Update the scale with the stddev and eps.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
scale[i] *= rsqrtf(var[i] + params.var_eps);
}
// The base pointers for reading/writing
uint16_t *const gmem_src = &params.gmem_src[thread_c];
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
const uint16_t *gmem_src1 = nullptr;
if (USE_ADD_RELU) {
gmem_src1 = &params.gmem_src1[thread_c];
}
// apply BN
for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) {
float x_math[ELEMENTS_PER_LDG];
zero_array(x_math);
if (is_valid_c) {
ldg(x_math, &gmem_src[nhw*params.c]);
}
// Normalize and apply activation function
normalize(x_math, bias, scale, mean);
if (USE_ADD_RELU) {
float x1_math[ELEMENTS_PER_LDG];
ldg(x1_math, &gmem_src1[nhw*params.c]);
add(x_math, x1_math);
relu_activation(x_math);
} else if (USE_RELU) {
relu_activation(x_math);
}
if (is_valid_c) {
stg(&gmem_dst[nhw*params.c], x_math);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormFwdParams {
// The input/output tensors.
uint16_t *gmem_src, *gmem_dst, *gmem_src1;
// The bias/scale.
float *gmem_bias, *gmem_scale;
// running mean/var (refer BN API from cudnn doc)
float *gmem_running_mean, *gmem_running_var;
// saved mean/var (refer BN API from cudnn doc)
float *gmem_saved_mean, *gmem_saved_var;
// ReLU bitmask
unsigned int *gmem_relu_bitmask;
// The dimensions.
int nhw, c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
float svar_inv_count;
// factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1).
float rvar_inv_count;
// The buffer to do the reduction for mean, stddev and count.
float *gmem_sums;
// The buffer to count items in the different CTAs.
int *gmem_counts;
// The counters of retired CTAs.
int *gmem_retired_ctas;
// The epsilon to apply to the computation of the variance.
float var_eps;
// outer loop count
int outer_loops;
// exponential average factor
float exp_avg_factor;
// number of C_ELEMENTS_PER_CTA-wide channel blocks (iterated via grid .y)
int c_blks;
void* my_data;
void* pair_data;
void* pair_data2;
int magic;
int sync_iters;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
bool USE_RELU,
bool USE_ADD_RELU,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
// Clamp thread_c so that we load from valid locations even if we don't use the value
if (!is_valid_c)
thread_c = params.c - 4;
// Single pass numerically stable algorithm, see:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
//
// n = 0, mean = 0.0, M2 = 0.0
//
// for x in data:
// n += 1
// delta = x - mean
// mean += delta/n
// delta2 = x - mean
// M2 += delta*delta2
//
// if n < 2:
// return float('nan')
// else:
// return M2 / (n - 1)
// Register to store the number of elements read so far.
float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG];
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
mean[i] = 0.f;
m2[i] = 0.f;
}
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointer to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute the mean/var across those elements.
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
// We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized; the offset is evenly divisible by 32
int offset = (pixels_per_iteration * OUTER_LOOPS +
PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31;
cta_nhw_regs -= offset;
cta_nhw_smem -= offset;
}
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) -
max(nhw_regs, 0), 0);
// Load the data and compute the local mean/sum and the variance.
if (USE_ONLINE_APPROACH) {
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
}
is_valid[i] = 1.f;
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
float delta0 = x_math[j] - mean[j];
mean[j] += delta0 * inv_count;
float delta1 = x_math[j] - mean[j];
m2[j] += delta0 * delta1 * is_valid[i];
}
}
} else {
// Read the elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
}
count += 1.f;
}
}
// Sum the elements in registers.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
mean[j] += x_math[j];
}
}
// Compute the mean.
float inv_count = 1.f / count;
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
mean[j] *= inv_count;
}
// Compute the variance.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Is it a valid pixel?
float is_valid = i < static_cast<int>(count) ? 1.f : 0.f;
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid;
}
}
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
float is_pixel_valid = (((unsigned int)idx <
(unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];
ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? idx : 0)*params.c]);
// The offset to store in SMEM.
const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
// Update the mean and m2 using deltas.
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
float delta0 = x_math[j] - mean[j];
mean[j] += delta0 * inv_count;
float delta1 = x_math[j] - mean[j];
m2[j] += delta0 * delta1 * is_pixel_valid;
}
}
}
// We scale the mean by the number of elements. It brings more stability.
float m1[ELEMENTS_PER_LDG];
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m1[i] = mean[i] * count;
}
// Run the parallel sum across the CTA to get the local sum.
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(m1, smem, thread_in_cta_c);
__syncthreads();
// Adjust the variance.
float inv_cta_count = 1.f / static_cast<float>(cta_count);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
float mean_diff = m1[i]*inv_cta_count - mean[i];
m2[i] = m2[i] + mean_diff * mean_diff * count;
}
// Run the parallel sum across the CTA to get the local adjusted variance.
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw);
// The workspace in global memory is distributed across the different CTAs.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, m1);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, m2);
}
// The memory location to store the number of pixels per CTA.
int *gmem_counts = &params.gmem_counts[c_blk_index*gridDim.x];
if (threadIdx.x == 0) {
gmem_counts[blockIdx.x] = cta_count;
}
// Read the bias and scale.
float bias[ELEMENTS_PER_LDG], scale[ELEMENTS_PER_LDG];
if (is_valid_c) {
read_from_gmem(bias, &params.gmem_bias[cta_c], thread_in_cta_c);
read_from_gmem(scale, &params.gmem_scale[cta_c], thread_in_cta_c);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
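// After this barrier, every CTA's partial sums and pixel counts written above are visible
// in gmem_sums/gmem_counts, so the grid-wide reductions below can safely read them.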
// Reset the mean to compute the global mean.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m1[i] = 0.f;
}
// Build the global mean.
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp[ELEMENTS_PER_LDG];
read_from_gmem(tmp, gmem_sums, idx);
add(m1, tmp);
}
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+3, params.magic, params.pair_data2, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(m1, smem, thread_in_cta_c);
__syncthreads();
// Normalize the mean.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m1[i] = m1[i] * params.svar_inv_count;
}
// Reset the variance.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m2[i] = 0.f;
}
// for add+relu fusion
const uint16_t *gmem_src1 = nullptr;
if (USE_ADD_RELU) {
gmem_src1 = &params.gmem_src1[thread_c];
}
// Build the global variance.
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
// Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration.
float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG];
read_from_gmem(tmp_mean, &gmem_sums[ 0], idx);
read_from_gmem(tmp_var, &gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx);
// Read the number of pixels visited by a given CTA.
cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]);
// Compute the diff to update the variance.
float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast<float>(cta_count);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
mean_diff[i] = m1[i] - tmp_mean[i]*inv_cta_count;
}
// Update the variance.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
m2[i] += tmp_var[i] + mean_diff[i]*mean_diff[i]*static_cast<float>(cta_count);
}
}
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+2, params.magic, params.pair_data2, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw);
}
__syncthreads();
read_from_smem(m2, smem, thread_in_cta_c);
// Finalize the stddev.
// because saved var and running var may have different denominators, we don't do it here
// scale_(m2, inv_count);
// store the saved mean/var
float svarinv[ELEMENTS_PER_LDG];
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps);
}
if (is_valid_for_saving) {
write_to_gmem(params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG, m1);
write_to_gmem(params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG, svarinv);
}
// store the running mean/var
float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG];
zero_array(rmean);
zero_array(rvar);
if (params.exp_avg_factor != 1.f && is_valid_for_saving) {
read_from_gmem(rmean, params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG);
read_from_gmem(rvar, params.gmem_running_var, thread_c/ELEMENTS_PER_LDG);
}
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] + \
params.exp_avg_factor * m1[i];
rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] + \
params.exp_avg_factor * (m2[i] * params.rvar_inv_count);
}
if (is_valid_for_saving) {
write_to_gmem(params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG, rmean);
write_to_gmem(params.gmem_running_var, thread_c/ELEMENTS_PER_LDG, rvar);
}
// Update the scale with the stddev and eps.
multiply(scale, svarinv);
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +
((params.nhw + 31) & ~31) * 2 * c_blk_index;
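// The bitmask records which outputs the fused add+ReLU zeroed (stored as warp-wide ballots),
// so nhwc_batch_norm_bwd_add_relu can re-apply exactly the same mask to dy without
// recomputing the forward output.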
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
const bool is_valid = is_valid_nhw && is_valid_c;
// Convert to float.
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
// Normalize and apply activation function
normalize(x_math, bias, scale, m1);
if (USE_ADD_RELU) {
float x1_math[ELEMENTS_PER_LDG];
ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
add(x_math, x1_math);
unsigned int relu_mask;
int lane_id = threadIdx.x & 31;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
bool rectified = x_math[i] < 0.0F;
unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);
if (lane_id == i) {
// Thread 0 remembers the relu_mask from the first time through this
// loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last.
relu_mask = local_relu_mask;
}
if (rectified) {
x_math[i] = 0.0F;
}
}
if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;
}
} else if (USE_RELU) {
relu_activation(x_math);
}
// Write back.
if (is_valid) {
stg_stream(&gmem_dst[idx*params.c], x_math);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
#pragma unroll 2
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
const bool is_valid = is_valid_nhw && is_valid_c;
// Read from SMEM.
const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
// Normalize and apply activation function
normalize(x_math, bias, scale, m1);
if (USE_ADD_RELU) {
float x1_math[ELEMENTS_PER_LDG];
ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]);
add(x_math, x1_math);
unsigned int relu_mask;
int lane_id = threadIdx.x & 31;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
bool rectified = x_math[i] < 0.0F;
unsigned int local_relu_mask = __ballot_sync(0xFFFFFFFFU, rectified);
if (lane_id == i) {
relu_mask = local_relu_mask;
}
if (rectified) {
x_math[i] = 0.0F;
}
}
if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) {
gmem_relu_bitmask[idx * 2 + lane_id] = relu_mask;
}
} else if (USE_RELU) {
relu_activation(x_math);
}
// Write back.
if (is_valid) {
stg_stream(&gmem_dst[idx*params.c], x_math);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
struct NhwcBatchNormBwdParams {
// The input/output tensors.
uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1;
// dscale/dbias
float *gmem_dscale, *gmem_dbias;
// The scale and bias.
float *gmem_scale, *gmem_bias;
// The mean/inv-var saved from fwd pass
float *gmem_saved_mean, *gmem_saved_var;
// ReLU bitmask
unsigned int *gmem_relu_bitmask;
// The dimensions.
int nhw, c;
// factor to scale sum of squared errors to get saved variance. Must be 1/nhw.
float svar_inv_count;
// The buffer to do the reduction for dscale and dbias
float *gmem_sums;
// The counters of retired CTAs.
int *gmem_retired_ctas;
// outer loop count
int outer_loops;
// number of CTAs along .x dimension
int c_blks;
void* my_data;
void* pair_data;
void* pair_data2;
int magic;
int sync_iters;
float wgrad_coeff;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N],
const float (&mean_var_scale_bias)[N],
const float (&var_scale)[N], bool valid_data) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
if ((y <= 0.f) && valid_data) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&y)[N], bool valid_data) {
#pragma unroll
for (int j = 0; j < N; ++j) {
if ((y[j] <= 0.f) && valid_data) {
dy[j] = 0.f;
}
}
}
template <int N>
DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) {
#pragma unroll
for (int j = 0; j < N; ++j) {
if (rectified[j] && valid_data) {
dy[j] = 0.f;
}
}
}
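// This overload consumes the per-element `rectified` flags decoded from the ReLU bitmask
// that the fused add+ReLU forward kernel wrote to global memory.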
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N],
const float (&x)[N],
const float (&mean_var_scale_bias)[N],
const float (&var_scale)[N]) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j];
if (y <= 0.f) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) {
#pragma unroll
for (int j = 0; j < N; ++j) {
if (y[j] <= 0.f) {
dy[j] = 0.f;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N],
const float (&dy)[N], const float (&x)[N],
const float (&mean)[N], float inv_count) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float delta0 = dy[j] - dbias[j];
dbias[j] += delta0 * inv_count;
delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j];
dscale[j] += delta0 * inv_count;
}
}
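// bwd_update is an online (Welford-style) accumulation: after all valid pixels have been
// processed, dbias holds the running mean of dy and dscale the running mean of
// dy * (x - mean); both are converted back into sums (multiplied by count) before the
// cross-CTA reductions.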
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int N>
DEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N],
const float (&var)[N], const float (&x)[N], const float (&mean)[N],
const float (&dscale)[N], const float (&dbias)[N], float inv_count) {
#pragma unroll
for (int j = 0; j < N; ++j) {
float tmp1 = dy[j] - (dbias[j]* inv_count);
float tmp2 = dscale[j] * inv_count;
float tmp3 = x[j] - mean[j];
dx[j] = var[j] * (tmp1 - (tmp2 * tmp3));
}
}
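// For reference: by the time bwd_dx runs, the caller has folded the saved statistics so that
// var = scale * inv_std, dscale = inv_std^2 * sum(dy * (x - mean)), dbias = sum(dy) and
// inv_count = 1 / pixel_count, which yields the usual batch-norm input gradient:
//   dx = scale * inv_std * (dy - sum(dy) * inv_count - (x - mean) * inv_std^2 * sum(dy * (x - mean)) * inv_count)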
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("starting nhwc_batch_norm_bwd\n");
#endif
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
// Registers to store the mean used for entire duration
float mean[ELEMENTS_PER_LDG];
zero_array(mean);
if (is_valid_c) {
read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
}
// accumulation related registers
float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
zero_array(dscale);
zero_array(dbias);
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointers to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute sum across them
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
// We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized
int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -
PIXELS_PER_CTA_IN_SMEM * gridDim.x;
cta_nhw_regs += offset;
cta_nhw_smem += offset;
}
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
zero_array(dy_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
ldg(dy_storage[i], &gmem_dy[idx*params.c]);
}
is_valid[i] = 1.f;
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float and update
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
bool is_pixel_valid = (((unsigned int)idx <
(unsigned int)params.nhw) && is_valid_c);
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid) {
ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
// The offset to store in SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
// Convert the running means (dbias, dscale) back into per-thread sums by scaling with the pixel count; reducing sums across CTAs is more stable.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dbias[i] *= count;
dscale[i] *= count;
}
// dscale parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
__syncthreads();
// The workspace in global memory is distributed across the different CTAs.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, dscale);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the accumulators for global summation
zero_array(dscale);
zero_array(dbias);
// Build the global accumulation
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
read_from_gmem(tmp1, gmem_sums, idx);
read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dscale[i] += tmp1[i];
dbias[i] += tmp2[i];
}
}
// dscale parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+1, params.magic, params.pair_data2, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+0, params.magic, params.pair_data2, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
// inv-var
float var[ELEMENTS_PER_LDG];
zero_array(var);
if (is_valid_c) {
read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
// Normalize the dscale.
multiply(dscale, var);
// store dscale/dbias
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
if (is_valid_for_saving) {
if (params.sync_iters>0) {
scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
} else {
write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
}
}
// scale
float scale[ELEMENTS_PER_LDG];
zero_array(scale);
if (is_valid_c) {
read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
// Further normalize the dscale to be used in dx calculation
multiply(dscale, var);
// scale the inv-var as well, afterwards
multiply(var, scale);
// inverse count
float inv_count = params.svar_inv_count;
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
if (is_valid) {
// Read from SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("starting nhwc_batch_norm_bwd_relu\n");
#endif
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
// Registers to store the mean/var/scale/bias used for the entire duration
// Register usage optimizations:
// 1. Can combine bias - (mean * var * scale) into a single register
// 2. Can combine var * scale into a single register
float varscale[ELEMENTS_PER_LDG];
zero_array(varscale);
if (is_valid_c) {
read_from_gmem(varscale, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
float tmp[ELEMENTS_PER_LDG];
zero_array(tmp);
if (is_valid_c) {
read_from_gmem(tmp, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
multiply(varscale, tmp);
float mean[ELEMENTS_PER_LDG];
zero_array(mean);
if (is_valid_c) {
read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
}
zero_array(tmp);
if (is_valid_c) {
read_from_gmem(tmp, params.gmem_bias, thread_c/ELEMENTS_PER_LDG);
}
float mean_var_scale_bias[ELEMENTS_PER_LDG];
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]);
}
// accumulation related registers
float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
zero_array(dscale);
zero_array(dbias);
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointers to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute sum across them
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
// We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized
int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS -
PIXELS_PER_CTA_IN_SMEM * gridDim.x;
cta_nhw_regs += offset;
cta_nhw_smem += offset;
}
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
zero_array(dy_storage[i]);
is_valid[i] = 0.f;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
ldg(dy_storage[i], &gmem_dy[idx*params.c]);
}
is_valid[i] = 1.f;
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float and update
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
bool is_pixel_valid = (((unsigned int)idx <
(unsigned int)params.nhw) && is_valid_c);
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid) {
ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
// The offset to store in SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
}
}
// Convert the running means (dbias, dscale) back into per-thread sums by scaling with the pixel count; reducing sums across CTAs is more stable.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dbias[i] *= count;
dscale[i] *= count;
}
// dscale parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
__syncthreads();
// The workspace in global memory is distributed across the different CTAs.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, dscale);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the accumulators for global summation
zero_array(dscale);
zero_array(dbias);
// Build the global accumulation
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
read_from_gmem(tmp1, gmem_sums, idx);
read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dscale[i] += tmp1[i];
dbias[i] += tmp2[i];
}
}
// dscale parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+1, params.magic, params.pair_data2, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+0, params.magic, params.pair_data2, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
// Normalize the dscale.
float var[ELEMENTS_PER_LDG];
zero_array(var);
if (is_valid_c) {
read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// store dscale/dbias
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
if (is_valid_for_saving) {
if (params.sync_iters>0) {
scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
} else {
write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
}
}
// Further normalize the dscale to be used in dx calculation
float scale[ELEMENTS_PER_LDG];
zero_array(scale);
if (is_valid_c) {
read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// scale the inv-var as well, afterwards
multiply(var, scale);
// inverse count
float inv_count = params.svar_inv_count;
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
// Convert to float.
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
if (is_valid) {
// Read from SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Storage,
int THREADS_PER_CTA,
int THREADS_PER_PIXEL,
int PIXELS_PER_THREAD_IN_REGISTERS,
int PIXELS_PER_THREAD_IN_SMEM,
int ELEMENTS_PER_LDG,
int USE_ONLINE_APPROACH,
int OUTER_LOOPS_,
int DESIRED_OCCUPANCY
>
__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) {
// The number of pixels loaded in a single LDG.
const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
// The number of pixels computed per CTA stored in registers.
const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG;
// The number of pixels computed per CTA stored in SMEM.
const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG;
// The number of C elements per CTA.
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/32)*ELEMENTS_PER_LDG];
// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
// The data type for packed storage in SMEM.
typedef typename PackedStorage_::Type PackedStorageType;
// The number of elements in the packed storage.
const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG;
// Registers to keep the data live for the persistent approach.
PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG];
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("starting nhwc_batch_norm_bwd_add_relu\n");
#endif
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
// The position in the NHW dimension where the CTA starts for the portion in SMEM.
int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM;
// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
// The position in the C dimension where the CTA starts.
const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA;
// Compute the C coordinate of the thread in the CTA.
const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL;
// Compute the C coordinate of the thread.
const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG;
// Is the thread working on a valid C dimension?
const int is_valid_c = thread_c < params.c;
float mean[ELEMENTS_PER_LDG];
zero_array(mean);
if (is_valid_c) {
read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG);
}
// accumulation related registers
float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG];
zero_array(dscale);
zero_array(dbias);
// The number of elements loaded by this CTA.
int cta_count = 0;
// The base pointers to load from.
const uint16_t *gmem_src = &params.gmem_src[thread_c];
const uint16_t *gmem_dy = &params.gmem_dy[thread_c];
uint16_t *gmem_dst1 = &params.gmem_dst1[thread_c];
// outer loops
int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops;
// Load the batch of elements. Compute sum across them
const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x;
if (OUTER_LOOPS_ != 1) {
// We cannot load everything to store persistently, so let's make sure registers and
// smem are fully utilized; the offset is evenly divisible by 32
int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x -
params.nhw) & ~31;
cta_nhw_regs -= offset;
cta_nhw_smem -= offset;
}
const unsigned int *const gmem_relu_bitmask = params.gmem_relu_bitmask +
((params.nhw + 31) & ~31) * 2 * c_blk_index;
#pragma unroll 1
for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) {
// The nhw position.
int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration;
// Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!!
cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs));
int lane_id = threadIdx.x & 31;
// Read the elements from memory.
float is_valid[PIXELS_PER_THREAD_IN_REGISTERS];
unsigned int relu_mask[PIXELS_PER_THREAD_IN_REGISTERS];
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
zero_array(x_storage[i]);
zero_array(dy_storage[i]);
is_valid[i] = 0.f;
const bool is_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
if (is_valid_nhw) {
if (is_valid_c) {
if (loop_i == OUTER_LOOPS - 1) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]);
} else {
ldg(x_storage[i], &gmem_src[idx*params.c]);
ldg(dy_storage[i], &gmem_dy[idx*params.c]);
}
is_valid[i] = 1.f;
}
if (lane_id < ELEMENTS_PER_LDG) {
relu_mask[i] = gmem_relu_bitmask[idx * 2 + lane_id];
}
}
}
// Do the math.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG;
// Convert to float and update
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
bool rectified[ELEMENTS_PER_LDG];
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask[i], j) &
(1U << lane_id)) != 0);
}
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
// Update the count.
count += is_valid[i];
// Invert the count.
float inv_count = is_valid[i] ? 1.f / count : 0.f;
relu_bwd(dy_math, rectified, is_valid[i]);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
// Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version
from_float(dy_storage[i], dy_math);
// dZ for elementwise add
if (is_valid[i]) {
if (loop_i == OUTER_LOOPS - 1) {
stg_stream(&gmem_dst1[idx*params.c], dy_storage[i]);
} else {
stg(&gmem_dst1[idx*params.c], dy_storage[i]);
}
}
}
}
// The elements to load and store in SMEM.
int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem;
// Load elements from SMEM, update the CTA count.
int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw);
if (pixels_in_smem > 0) {
cta_count += pixels_in_smem;
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_pixel_valid_nhw =
static_cast<unsigned int>(idx) < static_cast<unsigned int>(params.nhw);
const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
unsigned int relu_mask;
int lane_id = threadIdx.x & 31;
zero_array(x_storage_local);
zero_array(dy_storage_local);
if (is_pixel_valid_nhw) {
if (is_valid_c) {
ldg_stream(x_storage_local, &gmem_src[idx*params.c]);
ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]);
}
if (lane_id < ELEMENTS_PER_LDG) {
relu_mask = gmem_relu_bitmask[idx * 2 + lane_id];
}
}
bool rectified[ELEMENTS_PER_LDG];
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LDG; ++j) {
rectified[j] = ((__shfl_sync(0xFFFFFFFFU, relu_mask, j) &
(1U << lane_id)) != 0);
}
// The offset to store in SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Store in SMEM.
write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
// Update the count.
count += is_pixel_valid;
// Invert the count.
float inv_count = is_pixel_valid ? 1.f / count : 0.f;
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
relu_bwd(dy_math, rectified, is_pixel_valid);
bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count);
from_float(dy_storage_local, dy_math);
// dZ for elementwise add
if (is_pixel_valid) {
stg_stream(&gmem_dst1[idx*params.c], dy_storage_local);
}
// only store the 'relu-dgrad'ed version!
write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local);
}
}
// Convert the running means (dbias, dscale) back into per-thread sums by scaling with the pixel count; reducing sums across CTAs is more stable.
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dbias[i] *= count;
dscale[i] *= count;
}
// dscale parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
__syncthreads();
// The workspace in global memory is distributed across the different CTAs.
int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2;
// Write the data for the CTA to global memory.
float *gmem_sums = &params.gmem_sums[gmem_sums_offset];
if (threadIdx.x < THREADS_PER_PIXEL) {
const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x;
write_to_gmem(&gmem_sums[ 0], idx, dscale);
write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias);
}
// The counters to count how many CTAs have retired at this point.
// A given cta uses the same counter every other time through the outer loop.
int *gmem_retired_ctas = &params.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)];
inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0);
// Reset the accumulators for global summation
zero_array(dscale);
zero_array(dbias);
// Build the global accumulation
#pragma unroll 1
for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) {
float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG];
read_from_gmem(tmp1, gmem_sums, idx);
read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx);
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
dscale[i] += tmp1[i];
dbias[i] += tmp2[i];
}
}
// dscale parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+1, params.magic, params.pair_data2, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dscale, smem, thread_in_cta_c);
__syncthreads();
// dbias parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+0, params.magic, params.pair_data2, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
}
__syncthreads();
// The values in shared memory correspond to the CTA-wide sums.
read_from_smem(dbias, smem, thread_in_cta_c);
// Normalize the dscale.
float var[ELEMENTS_PER_LDG];
zero_array(var);
if (is_valid_c) {
read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// store dscale/dbias
bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0;
if (is_valid_for_saving) {
if (params.sync_iters>0) {
scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff);
scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff);
} else {
write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale);
write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias);
}
}
// Further normalize the dscale to be used in dx calculation
float scale[ELEMENTS_PER_LDG];
zero_array(scale);
if (is_valid_c) {
read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG);
}
multiply(dscale, var);
// scale the inv-var as well, afterwards
multiply(var, scale);
// inverse count
float inv_count = params.svar_inv_count;
// The base pointer to write to.
uint16_t *const gmem_dst = &params.gmem_dst[thread_c];
// Store the elements in registers.
#pragma unroll 1
for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) {
// The value for nhw.
int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration;
// Normalize the elements and write to memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
// Convert to float.
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage[i]);
to_float(dy_math, dy_storage[i]);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
if (is_valid) {
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
// The next value of nhw.
out_nhw -= pixels_per_iteration;
// Read the next elements from memory.
#pragma unroll
for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) {
const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
float y[ELEMENTS_PER_LDG];
zero_array(y);
if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) {
ldg_stream(x_storage[i], &gmem_src[idx*params.c]);
ldg_stream(dy_storage[i], &gmem_dst1[idx*params.c]);
}
}
}
// Normalize the elements from SMEM and write them out.
if (pixels_in_smem > 0) {
for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) {
const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG;
const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c;
if (is_valid) {
// Read from SMEM.
int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG],
dy_storage_local[PACKED_ELEMENTS_PER_LDG];
read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x);
offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG;
read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x);
float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG];
to_float(x_math, x_storage_local);
to_float(dy_math, dy_storage_local);
float dx[ELEMENTS_PER_LDG];
bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count);
// Write back.
stg_stream(&gmem_dst[idx*params.c], dx);
}
}
}
// We're about to start on the next c-blk. Needed?
__syncthreads();
}
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_
try:
import torch
import bnp
from .batch_norm import BatchNorm2d_NHWC
del torch
del bnp
del batch_norm
except ImportError as err:
print("apex was installed without --bnp flag, contrib.groupbn is not available")
import torch
import numpy as np
from torch.nn.modules.batchnorm import _BatchNorm
import bnp
class bn_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu=False, is_train=True, bn_group=1, my_data=None, pair_data=None, magic=1, pair_data2=None, max_cta_per_sm=2, cta_launch_margin=12):
if is_train:
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.fuse_relu = fuse_relu
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.bn_group = bn_group
ctx.max_cta_per_sm = max_cta_per_sm
ctx.cta_launch_margin = cta_launch_margin
res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
return res
else:
return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, bn_group, mom, epsilon, fuse_relu)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
epsilon = ctx.epsilon
mom = ctx.momentum
fuse_relu = ctx.fuse_relu
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
bn_group = ctx.bn_group
max_cta_per_sm = ctx.max_cta_per_sm
cta_launch_margin = ctx.cta_launch_margin
dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
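# autograd requires one gradient per forward() argument, so the non-tensor arguments
# (momentum, epsilon, flags, sync handles, ...) are padded with None below.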
return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class bn_addrelu_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, is_train=True, bn_group=1, my_data=None, pair_data=None, magic=1, pair_data2=None, max_cta_per_sm=2, cta_launch_margin=12):
if is_train:
bitmask = torch.cuda.IntTensor(x.numel()//32)
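# sized to roughly one bit per element of x; the forward kernel fills it with which
# activations the fused add+ReLU zeroed, and the backward kernel reuses that mask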
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.bn_group = bn_group
ctx.max_cta_per_sm = max_cta_per_sm
ctx.cta_launch_margin = cta_launch_margin
res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, mom, epsilon, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
return res
else:
return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, bn_group, mom, epsilon)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
epsilon = ctx.epsilon
mom = ctx.momentum
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
bn_group = ctx.bn_group
max_cta_per_sm = ctx.max_cta_per_sm
cta_launch_margin = ctx.cta_launch_margin
dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, mom, epsilon, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class BatchNorm2d_NHWC(_BatchNorm):
def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12):
super(BatchNorm2d_NHWC, self).__init__(num_features)
self.fuse_relu = fuse_relu
self.minibatch_mean = torch.cuda.FloatTensor(num_features)
self.minibatch_riv = torch.cuda.FloatTensor(num_features)
# default to distributed bn disabled
self.bn_group = bn_group
self.max_cta_per_sm = max_cta_per_sm  # used only in training fwd and bwd
self.cta_launch_margin = cta_launch_margin  # used only in training fwd and bwd
self.my_data = None
self.pair_data = None
self.pair_data2 = None
self.local_rank = 0
self.magic = torch.IntTensor([0])
assert(max_cta_per_sm>0) # won't be able to do much with 0 CTAs :)
#FIXME: turn pair handles into an array
if bn_group>1:
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert(world_size >= bn_group)
assert(world_size % bn_group == 0)
bn_sync_steps = 1
if (bn_group==4):
bn_sync_steps = 2
self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
self.my_data = bnp.get_data_ptr(self.ipc_buffer)
# we are walking on very thin ice here by utilizing internal `_share_cuda_()`
self.storage = self.ipc_buffer.storage()
self.share_cuda = self.storage._share_cuda_()
internal_cuda_mem = self.share_cuda
# internal_cuda_mem[1]: ipc_mem_handle
my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))
# internal_cuda_mem[3]: offset
my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])
handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device)
handles_l = list(handles_all.unbind(0))
torch.distributed.all_gather(handles_l, my_handle)
offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device)
offsets_l = list(offsets_all.unbind(0))
torch.distributed.all_gather(offsets_l, my_offset)
#whom do I actually care about? that would be local_rank XOR 1
self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()
pair_offset = offsets_l[local_rank ^ 1].cpu()
self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)
if bn_group>2:
self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()
pair_offset2 = offsets_l[local_rank ^ 2].cpu()
self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)
#FIXME: get magic value into C code and eliminate from here
self.magic = torch.IntTensor([2])
self.local_rank = local_rank
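The handle exchange above hands each rank the IPC buffer of rank ^ 1 and, when bn_group > 2, also of rank ^ 2 — a butterfly pattern that covers a group of four in the two synchronization steps set by bn_sync_steps. A standalone sketch of which peers each rank ends up talking to (no distributed setup needed, purely illustrative):

# Peer selection used by BatchNorm2d_NHWC for the remote IPC buffers.
for bn_group in (2, 4):
    print("bn_group =", bn_group)
    for rank in range(bn_group):
        peers = [rank ^ 1] + ([rank ^ 2] if bn_group > 2 else [])
        print("  rank", rank, "-> peers", peers)
# bn_group = 4 prints: rank 0 -> [1, 2], rank 1 -> [0, 3], rank 2 -> [3, 0], rank 3 -> [2, 1]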
def forward(self, x, z=None):
if z is not None:
assert(self.fuse_relu==True)
return bn_addrelu_NHWC_impl.apply(x, z,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv,
self.momentum,
self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2,
self.max_cta_per_sm, self.cta_launch_margin)
else:
return bn_NHWC_impl.apply(x,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv,
self.momentum,
self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2,
self.max_cta_per_sm, self.cta_launch_margin)
def __del__(self):
if self.bn_group>1:
bnp.close_remote_data(self.pair_handle)
if self.bn_group>2:
bnp.close_remote_data(self.pair_handle2)
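A minimal usage sketch for the layer above, assuming single-GPU operation (bn_group=1) and that the class is importable from apex's contrib tree as shown; the exact import path is an assumption, and inputs must be fp16 NHWC as the bnp kernels require:

import torch
# assumed import location; adjust to where BatchNorm2d_NHWC lands in your install
from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC

bn = BatchNorm2d_NHWC(64, fuse_relu=True).cuda()
x = torch.randn(8, 14, 14, 64, device='cuda', dtype=torch.half, requires_grad=True)  # N, H, W, C
y = bn(x)                    # training mode: fused BN + ReLU via bnp.bn_fwd_nhwc
y.float().sum().backward()   # gradients via bnp.bn_bwd_nhwc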
@@ -55,10 +55,11 @@ class SyncBatchNorm(_BatchNorm):
>>> inp = torch.randn(10, 14, 14, 100).cuda()
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group
self.channel_last = channel_last
self.fuse_relu = fuse_relu
def _specify_process_group(self, process_group):
self.process_group = process_group
@@ -66,11 +67,11 @@ class SyncBatchNorm(_BatchNorm):
def _specify_channel_last(self, channel_last):
self.channel_last = channel_last
def forward(self, input):
def forward(self, input, z = None):
# if input.dim() == 2, we switch to channel_last for efficient memory accessing
channel_last = self.channel_last if input.dim() != 2 else True
if not self.training and self.track_running_stats and not channel_last:
if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
# fall back to pytorch implementation for inference
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else:
@@ -81,4 +82,4 @@ class SyncBatchNorm(_BatchNorm):
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:
exponential_average_factor = self.momentum
return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last)
return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
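With this change the layer accepts an optional residual z and fuses add+ReLU when fuse_relu=True. A rough usage sketch, assuming the syncbn extension is built and the process group has already been initialized by the launcher:

import torch
from apex.parallel import SyncBatchNorm  # assumed to resolve to this optimized implementation

# torch.distributed is assumed to be initialized already (e.g. by torch.distributed.launch)
sbn = SyncBatchNorm(100, channel_last=True, fuse_relu=True).cuda()
x = torch.randn(10, 14, 14, 100, device='cuda', requires_grad=True)
z = torch.randn_like(x)   # residual branch, added before the fused ReLU
out = sbn(x, z)           # training path: SyncBatchnormFunction with z and fuse_relu
out.sum().backward()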
@@ -7,7 +7,7 @@ from apex.parallel import ReduceOp
class SyncBatchnormFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False):
def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):
torch.cuda.nvtx.range_push("sync_BN_fw")
input = input.contiguous()
world_size = 0
@@ -53,13 +53,14 @@ class SyncBatchnormFunction(Function):
mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, inv_std)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
ctx.fuse_relu = fuse_relu
if channel_last:
out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias)
out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu)
else:
out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
@@ -73,11 +74,17 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, mean, inv_std = ctx.saved_tensors
saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
grad_input = grad_weight = grad_bias = None
fuse_relu = ctx.fuse_relu
grad_input = grad_z = grad_weight = grad_bias = None
if fuse_relu:
grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias)
if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
grad_z = grad_output.clone()
# TODO(jie): why do I have to clone here? life time of grad_output?
if channel_last:
@@ -100,11 +107,11 @@ class SyncBatchnormFunction(Function):
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
if weight is None or not ctx.needs_input_grad[1]:
if weight is None or not ctx.needs_input_grad[2]:
grad_weight = None
if weight is None or not ctx.needs_input_grad[2]:
if weight is None or not ctx.needs_input_grad[3]:
grad_bias = None
torch.cuda.nvtx.range_pop()
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
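Once the ReLU mask has been applied to grad_output, the gradient flowing into the residual branch z is identical to it, which is why grad_z above is just a clone. A tiny runnable check of that relationship (illustrative only):

import torch

a = torch.randn(5, requires_grad=True)   # stands in for the BN output
z = torch.randn(5, requires_grad=True)   # residual branch
torch.relu(a + z).sum().backward()
assert torch.equal(a.grad, z.grad)       # both branches receive the same masked gradient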
@@ -55,10 +55,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
const at::optional<at::Tensor> shift,
const bool fuse_relu);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
@@ -82,6 +84,15 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
@@ -92,4 +103,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
}
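For reference, the new z and fuse_relu arguments surface on the Python side exactly as bound here. A rough direct-call sketch against the extension (module name syncbn as in apex's --cuda_ext build; the welford helper is assumed to return a (mean, biased_var) pair, as the surrounding apex code uses it):

import torch
import syncbn  # extension built by apex's --cuda_ext option

eps = 1e-5
x = torch.randn(10, 14, 14, 100, device='cuda')   # "channel last": C is the innermost dimension
z = torch.randn_like(x)
weight = torch.ones(100, device='cuda')
bias = torch.zeros(100, device='cuda')

mean, var = syncbn.welford_mean_var_c_last(x)     # per-channel statistics (assumed binding name)
inv_std = 1.0 / torch.sqrt(var + eps)
out = syncbn.batchnorm_forward_c_last(x, z, mean, inv_std, weight, bias, True)  # fused BN + add + ReLU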
@@ -590,6 +590,58 @@ template <
int PARALLEL_LOADS>
__global__ void batchnorm_forward_c_last_kernel(
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ z,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const layerscalar_t* __restrict__ shift,
scalar_t* __restrict__ out,
const int reduction_size,
const int stride,
const bool fuse_relu) {
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
// offset along m dimension
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
auto m_c = mean[c_offset];
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
int address_increment = inner_loop_stride * stride;
for (int i = 0; i < loop_count; i++) {
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
if (z != NULL) {
tmp += z[address_base];
}
out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
}
m_offset += inner_loop_stride;
address_base += address_increment;
}
}
}
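In scalar terms the kernel above writes w_c * (x - mean_c) * inv_std_c + shift_c, adds z when a residual is given, and clamps at zero when fuse_relu is set, all computed in the accumulation type. A plain PyTorch reference of the same per-channel math (illustrative only, not part of the extension):

import torch

def bn_add_relu_reference(x, z, mean, inv_std, weight, shift, fuse_relu):
    # x, z: (..., C) channels-last; mean / inv_std / weight / shift: (C,)
    y = weight * (x.float() - mean) * inv_std + shift
    if z is not None:
        y = y + z.float()
    if fuse_relu:
        y = torch.relu(y)
    return y.to(x.dtype)

x = torch.randn(2, 4, 4, 8)
c = x.size(-1)
y = bn_add_relu_reference(x, torch.randn_like(x), torch.zeros(c), torch.ones(c),
                          torch.ones(c), torch.zeros(c), fuse_relu=True)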
// elementwise BN kernel
template <
typename scalar_t,
typename accscalar_t,
typename layerscalar_t,
int PARALLEL_LOADS>
__global__ void relu_backward_c_last_kernel(
const scalar_t* __restrict__ grad_output,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ z,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
@@ -618,9 +670,11 @@ __global__ void batchnorm_forward_c_last_kernel(
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
out[address_base] = static_cast<scalar_t>(
w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c
);
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
if (z != NULL) {
tmp += z[address_base];
}
out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
}
m_offset += inner_loop_stride;
address_base += address_increment;
@@ -1146,10 +1200,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
at::Tensor batchnorm_forward_c_last_CUDA(
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const at::optional<at::Tensor> shift,
const bool fuse_relu) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
@@ -1169,13 +1225,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
stride,
fuse_relu);
);
} else {
if (weight.has_value()) {
@@ -1188,13 +1246,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
stride,
fuse_relu);
);
}
return out;
@@ -1350,3 +1410,66 @@ at::Tensor batchnorm_backward_c_last_CUDA(
return grad_input;
}
at::Tensor relu_backward_c_last_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor out = at::empty_like(input);
dim3 block;
dim3 grid;
flexible_launch_configs(reduction_size, stride, block, grid);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "relu_backward_c_last",
using accscalar_t = at::acc_type<scalar_t_0, true>;
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
);
} else {
if (weight.has_value()) {
AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "relu_backward_c_last",
using accscalar_t = at::acc_type<scalar_t_0, true>;
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
);
}
return out;
}
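relu_backward_c_last_CUDA recomputes the pre-ReLU value w * (x - mean) * inv_std + shift (+ z) and lets grad_output through only where that value was positive, which is exactly the masking the fused forward implies. A compact PyTorch reference (illustrative only):

import torch

def relu_backward_reference(grad_output, x, z, mean, inv_std, weight, shift):
    # recompute the pre-ReLU activation, then gate the incoming gradient on its sign
    pre_relu = weight * (x.float() - mean) * inv_std + shift
    if z is not None:
        pre_relu = pre_relu + z.float()
    return torch.where(pre_relu > 0, grad_output, torch.zeros_like(grad_output))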
@@ -99,6 +99,35 @@ if "--cuda_ext" in sys.argv:
'-O3',
'--use_fast_math'] + version_ge_1_1}))
if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--bnp")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
ext_modules.append(
CUDAExtension(name='bnp',
sources=['apex/contrib/csrc/groupbn/batch_norm.cu',
'apex/contrib/csrc/groupbn/ipc.cu',
'apex/contrib/csrc/groupbn/interface.cpp',
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
extra_compile_args={'cxx': [] + version_ge_1_1,
'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
'-gencode',
'arch=compute_70,code=sm_70'] + version_ge_1_1}))
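This block only runs when --bnp appears on the setup command line (e.g. python setup.py install --cpp_ext --cuda_ext --bnp; the exact invocation depends on how you install apex) and produces a top-level extension module named bnp. A quick post-install sanity check using the entry points already exercised earlier in this PR:

import torch
import bnp  # top-level extension module produced by the block above

# size in bytes of the IPC buffer for a given number of sync steps, as used by
# BatchNorm2d_NHWC when bn_group > 1
buf = torch.cuda.ByteTensor(bnp.get_buffer_size(1))
print(buf.numel(), bnp.get_data_ptr(buf))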
setup(
name='apex',
version='0.1',