Commit fedfe0d7 authored by jjsjann123, committed by mcarilli

Bnp integration pr (#275)

* Persistent group batchnorm added

Added persistent grouped batch norm for performance runs in the strong-scaling case.
Currently it only supports:

  1. NHWC layout
  2. fp16
  3. synchronization only within a node

An environment variable is used to tune LAUNCH_MARGIN, which limits the number of
CTAs used by the persistent kernel.

Documentation and examples will follow.

* updating type().scalarType() to scalar_type()

* moving launch margin to be defined at layer creation; adding a knob to cap max CTAs per SM

* fixing the CTA computation

* review comments:

set device_id through cudaGetDevice()
move cudaMemset to cudaMemsetAsync
update __threadfence() to __threadfence_system() for inter-device writes
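
A minimal usage sketch (not part of this commit; it assumes the layer is importable as
apex.contrib.groupbn.batch_norm.BatchNorm2d_NHWC and that torch.distributed is already
initialized whenever bn_group > 1):

    import torch
    from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC

    # 64 channels; cap the persistent kernel at 2 CTAs/SM, cta_launch_margin leaves headroom for concurrent kernels
    bn = BatchNorm2d_NHWC(64, fuse_relu=True,
                          bn_group=1, max_cta_per_sm=2, cta_launch_margin=12).cuda()

    x = torch.randn(32, 14, 14, 64, dtype=torch.half, device='cuda')  # NHWC, fp16 only
    y = bn(x)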
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm.h"
#include <cuda.h>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
return ((x + multiple - 1) / multiple) * multiple;
}
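// Example: round_up_to_multiple(100, 512) == 512; used below to 512-byte-align workspace offsets.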
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
THCudaFree(at::globalContext().lazyInitCUDA(), data);
}
}
size_t size;
void* data;
};
// Return {y}
at::Tensor nhwc_bn_fwd_train(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for sync
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
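// Workspace layout handed to the kernel below:
//   [0] minibatch mean, [1] minibatch inverse variance,
//   [2] retired-CTA counter (zeroed asynchronously before launch),
//   [3..] scratch carved out of the single Workspace allocation above.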
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
// fuse_relu is forwarded to the persistent training kernel
bn->fwd(stream, fuse_relu, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const int bn_group,
const float momentum,
const float epsilon,
const bool fuse_relu) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(nullptr);
workspace.push_back(nullptr);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// fuse_relu is forwarded to the inference kernel
bn->fwdInference(stream, fuse_relu);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
// shape
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for sync
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// outputs
at::Tensor x_grad, scale_grad, bias_grad;
// Allocate outputs
x_grad = at::empty_like(x);
scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias);
// Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
x_grad.data<at::Half>(),
nullptr,
dy.data<at::Half>());
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {scale_grad.data<float>(), bias_grad.data<float>()});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 3; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-3];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
bn->dgrad(stream, fuse_relu, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include "batch_norm_add_relu.h"
#include <cuda.h>
//FIXME move the common stuff to common h file
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
static size_t round_up_to_multiple(size_t x, int multiple) {
return ((x + multiple - 1) / multiple) * multiple;
}
// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = THCudaMalloc(at::globalContext().lazyInitCUDA(), size);
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
THCudaFree(at::globalContext().lazyInitCUDA(), data);
}
}
size_t size;
void* data;
};
// Return {y}
at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for sync
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr,
z.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
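// Workspace layout handed to the kernel below:
//   [0] minibatch mean, [1] minibatch inverse variance, [2] ReLU bitmask,
//   [3] retired-CTA counter (zeroed asynchronously before launch),
//   [4..] scratch carved out of the single Workspace allocation above.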
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
workspace.push_back(bitmask.data<int32_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
// add + ReLU is always fused in this variant
bn->fwd(stream, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const int bn_group,
const float momentum,
const float epsilon) {
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options());
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
nullptr,
y.data<at::Half>(),
nullptr,
z.data<at::Half>(),
nullptr);
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {nullptr, nullptr});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(nullptr);
workspace.push_back(nullptr);
workspace.push_back(nullptr);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
// add + ReLU is always fused in this variant
bn->fwdInference(stream);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
// shape
const int N = x.size(0);
const int H = x.size(1);
const int W = x.size(2);
const int C = x.size(3);
// generate a new magic number and use it for sync
int* magic = magic_tensor.data<int>();
*magic = (*magic + 1) & 0xff;
// outputs
at::Tensor x_grad, z_grad, scale_grad, bias_grad;
// Allocate outputs
x_grad = at::empty_like(x);
z_grad = at::empty_like(x);
scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias);
// Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
bn->setInputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W, bn_group);
bn->setOutputDescriptor(CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF, N, C, H, W);
bn->setConstants(momentum, epsilon);
// set pointers within the wrapper
bn->setInputOutputPointers(x.data<at::Half>(),
x_grad.data<at::Half>(),
nullptr,
dy.data<at::Half>(),
nullptr,
z_grad.data<at::Half>());
bn->setWeightPointers({scale.data<float>(), bias.data<float>()}, {scale_grad.data<float>(), bias_grad.data<float>()});
bn->setParameterPointers({running_mean.data<float>(), running_inv_var.data<float>()});
// deal with workspace(s)
auto workspace_bytes = bn->numWorkspaceBytes();
// We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset
// an allocated workspace for the others
size_t total_workspace_bytes = 0;
std::vector<size_t> workspace_offsets;
for (auto index = 4; index < workspace_bytes.size(); ++index) {
total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512);
workspace_offsets.push_back(total_workspace_bytes);
auto alloc_bytes = workspace_bytes[index];
total_workspace_bytes += alloc_bytes;
}
// Allocate the workspace
Workspace ws(total_workspace_bytes);
std::vector<void *> workspace;
workspace.push_back(minibatch_mean.data<float>());
workspace.push_back(minibatch_inv_var.data<float>());
workspace.push_back(bitmask.data<int32_t>());
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
void *ptr = reinterpret_cast<uint8_t*>(ws.data) + workspace_offsets[index-4];
workspace.push_back(ptr);
}
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
bn->dgrad(stream, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};
}
#include <ATen/cuda/CUDAContext.h>
#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H
namespace at {
namespace cuda {
namespace utils {
//eventually should be replaced by real query functions
static inline int MultiprocessorCount(int device_id) {
return getDeviceProperties(device_id)->multiProcessorCount;
}
static inline int SMArch(int device_id) {
auto device_property = getDeviceProperties(device_id);
int cc = device_property->major * 10 + device_property->minor;
return cc;
}
static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
}
}
}
}
#endif
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ArrayRef.h>
#include <ATen/ScalarType.h>
#include "ATen/Scalar.h"
#include "ATen/Type.h"
#include "ATen/Tensor.h"
#include "ATen/Storage.h"
#include "ATen/Generator.h"
namespace py = pybind11;
int64_t get_buffer_size(
const int bn_sync_steps);
void* get_data_ptr(
const at::Tensor& data);
void* get_remote_data_ptr(
const at::Tensor& handle,
const int64_t offset);
void close_remote_data(
const at::Tensor& handle);
at::Tensor nhwc_bn_fwd_train(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& x,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const int bn_group,
const float momentum,
const float epsilon,
const bool fuse_relu);
std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& x,
const at::Tensor& z,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const int bn_group,
const float momentum,
const float epsilon);
std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& x,
const at::Tensor& dy,
const at::Tensor& scale,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_buffer_size", &get_buffer_size, "get_buffer_size");
m.def("get_data_ptr", &get_data_ptr, "get_data_ptr");
m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr");
m.def("close_remote_data", &close_remote_data, "close_remote_data");
m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc");
m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc");
m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc");
m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc");
m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc");
m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include "THC/THC.h"
#include <cuda.h>
#define cudaCheckErrors(msg) \
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
msg, cudaGetErrorString(__err), \
__FILE__, __LINE__); \
fprintf(stderr, "*** FAILED - ABORTING\n"); \
exit(1); \
} \
} while (0)
template<>
struct std::hash<cudaIpcMemHandle_t> {
size_t operator() (const cudaIpcMemHandle_t& handle) const {
size_t hash = 0;
uint8_t* ptr = (uint8_t*)&handle;
assert(sizeof(uint8_t) == 1);
for (int i=0; i<sizeof(cudaIpcMemHandle_t); i++) {
hash += *ptr;
ptr++;
}
return hash;
}
};
template<>
struct std::equal_to<cudaIpcMemHandle_t> {
bool operator() (const cudaIpcMemHandle_t &lhs,
const cudaIpcMemHandle_t &rhs) const {
return (std::memcmp((void*) &lhs,
(void*) &rhs,
sizeof(cudaIpcMemHandle_t)) == 0);
}
};
namespace {
namespace gpuipc {
//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The number of reducing ops, each uses its own space : mean, var, dscale, dbias
const int REDUCE_OPS = 4;
// Maximum block.y supported - limited due to buffer allocation
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
const int BYTES_PER_ELEM = 4;
// Buffer size per sync step
const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*(1+ELEMENTS_PER_LDG)*BYTES_PER_ELEM;
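// With the constants above: 1024 * 16 * (1 + 4) * 4 = 327680 bytes (320 KiB) per sync step.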
};
class IpcMemHandleRegistry {
public:
void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) {
if (registry_.count(handle) == 0) {
registry_.insert(std::make_pair(handle, RegistryEntry()));
registry_[handle].dev_ptr = ipcOpenMem(handle);
}
registry_[handle].ref_count++;
return (((uint8_t*)registry_[handle].dev_ptr) + offset);
}
void releasePtr(const cudaIpcMemHandle_t& handle) {
if (registry_.count(handle) == 0) {
return;  // nothing to release for an unmapped handle
}
if (--registry_[handle].ref_count == 0) {
ipcCloseMem(registry_[handle].dev_ptr);
registry_.erase(handle);
}
}
struct RegistryEntry {
void* dev_ptr;
int ref_count;
RegistryEntry() : dev_ptr(NULL) , ref_count(0) {}
};
protected:
std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;
void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {
void *data;
cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);
cudaCheckErrors("ipc init");
return data;
}
void ipcCloseMem(void* dev_ptr) {
cudaIpcCloseMemHandle(dev_ptr);
cudaCheckErrors("ipc close");
}
};
}
static IpcMemHandleRegistry ipc_mem_registry;
int64_t get_buffer_size(const int bn_sync_steps) {
return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES;
}
void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {
cudaIpcMemHandle_t my_handle;
memcpy((unsigned char *)(&my_handle), handle.data<uint8_t>(), sizeof(my_handle));
return ipc_mem_registry.getPtr(my_handle, offset);
}
void close_remote_data(const at::Tensor& handle) {
cudaIpcMemHandle_t my_handle;
memcpy((unsigned char *)(&my_handle), handle.data<uint8_t>(), sizeof(my_handle));
ipc_mem_registry.releasePtr(my_handle);
}
void* get_data_ptr(
const at::Tensor& data) {
return data.data<uint8_t>();
}
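// Descriptive summary of the intended IPC flow (mirrors the Python-side usage in
// apex/contrib/groupbn/batch_norm.py; no new functionality here):
//   1. each rank allocates a CUDA byte buffer of get_buffer_size(bn_sync_steps) bytes and
//      exchanges its cudaIpcMemHandle_t plus storage offset with its peers;
//   2. a peer maps the remote buffer via get_remote_data_ptr(handle, offset); repeated calls
//      for the same handle reuse the cached mapping and bump its reference count;
//   3. close_remote_data(handle) decrements the count and unmaps the buffer when it reaches zero.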
try:
import torch
import bnp
from .batch_norm import BatchNorm2d_NHWC
del torch
del bnp
del batch_norm
except ImportError as err:
print("apex was installed without --bnp flag, contrib.groupbn is not available")
import torch
import numpy as np
from torch.nn.modules.batchnorm import _BatchNorm
import bnp
class bn_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu=False, is_train=True, bn_group=1, my_data=None, pair_data=None, magic=1, pair_data2=None, max_cta_per_sm=2, cta_launch_margin=12):
if is_train:
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.fuse_relu = fuse_relu
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.bn_group = bn_group
ctx.max_cta_per_sm = max_cta_per_sm
ctx.cta_launch_margin = cta_launch_margin
res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
return res
else:
return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, bn_group, mom, epsilon, fuse_relu)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
epsilon = ctx.epsilon
mom = ctx.momentum
fuse_relu = ctx.fuse_relu
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
bn_group = ctx.bn_group
max_cta_per_sm = ctx.max_cta_per_sm
cta_launch_margin = ctx.cta_launch_margin
dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class bn_addrelu_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, is_train=True, bn_group=1, my_data=None, pair_data=None, magic=1, pair_data2=None, max_cta_per_sm=2, cta_launch_margin=12):
if is_train:
bitmask = torch.cuda.IntTensor(x.numel()//32)
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.bn_group = bn_group
ctx.max_cta_per_sm = max_cta_per_sm
ctx.cta_launch_margin = cta_launch_margin
res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, mom, epsilon, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
return res
else:
return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, bn_group, mom, epsilon)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
epsilon = ctx.epsilon
mom = ctx.momentum
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
bn_group = ctx.bn_group
max_cta_per_sm = ctx.max_cta_per_sm
cta_launch_margin = ctx.cta_launch_margin
dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, mom, epsilon, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class BatchNorm2d_NHWC(_BatchNorm):
def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12):
super(BatchNorm2d_NHWC, self).__init__(num_features)
self.fuse_relu = fuse_relu
self.minibatch_mean = torch.cuda.FloatTensor(num_features)
self.minibatch_riv = torch.cuda.FloatTensor(num_features)
#default to distributed bn disabled
self.bn_group = bn_group
self.max_cta_per_sm = max_cta_per_sm #used only in training fwd and bwd
self.cta_launch_margin = cta_launch_margin #used only in training fwd and bwd
self.my_data = None
self.pair_data = None
self.pair_data2 = None
self.local_rank = 0
self.magic = torch.IntTensor([0])
assert(max_cta_per_sm>0) # won't be able to do much with 0 CTAs :)
#FIXME: turn pair handles into an array
if bn_group>1:
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert(world_size >= bn_group)
assert(world_size % bn_group == 0)
bn_sync_steps = 1
if (bn_group==4):
bn_sync_steps = 2
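# a group of 4 syncs through two pair exchanges (partners local_rank^1 and local_rank^2 below),
# so two sync-buffer slots are reserved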
self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
self.my_data = bnp.get_data_ptr(self.ipc_buffer)
# we are walking on very thin ice here by utilizing internal `_share_cuda_()`
self.storage = self.ipc_buffer.storage()
self.share_cuda = self.storage._share_cuda_()
internal_cuda_mem = self.share_cuda
# internal_cuda_mem[1]: ipc_mem_handle
my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))
# internal_cuda_mem[3]: offset
my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])
handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device)
handles_l = list(handles_all.unbind(0))
torch.distributed.all_gather(handles_l, my_handle)
offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device)
offsets_l = list(offsets_all.unbind(0))
torch.distributed.all_gather(offsets_l, my_offset)
#whom do I actually care about? that would be local_rank XOR 1
self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()
pair_offset = offsets_l[local_rank ^ 1].cpu()
self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)
if bn_group>2:
self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()
pair_offset2 = offsets_l[local_rank ^ 2].cpu()
self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)
#FIXME: get magic value into C code and eliminate from here
self.magic = torch.IntTensor([2])
self.local_rank = local_rank
def forward(self, x, z=None):
if z is not None:
assert(self.fuse_relu==True)
return bn_addrelu_NHWC_impl.apply(x, z,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv,
self.momentum,
self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2,
self.max_cta_per_sm, self.cta_launch_margin)
else:
return bn_NHWC_impl.apply(x,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv,
self.momentum,
self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2,
self.max_cta_per_sm, self.cta_launch_margin)
def __del__(self):
if self.bn_group>1:
bnp.close_remote_data(self.pair_handle)
if self.bn_group>2:
bnp.close_remote_data(self.pair_handle2)
@@ -55,10 +55,11 @@ class SyncBatchNorm(_BatchNorm):
>>> inp = torch.randn(10, 14, 14, 100).cuda()
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group
self.channel_last = channel_last
self.fuse_relu = fuse_relu
def _specify_process_group(self, process_group):
self.process_group = process_group
@@ -66,11 +67,11 @@ class SyncBatchNorm(_BatchNorm):
def _specify_channel_last(self, channel_last):
self.channel_last = channel_last
def forward(self, input):
def forward(self, input, z = None):
# if input.dim() == 2, we switch to channel_last for efficient memory accessing
channel_last = self.channel_last if input.dim() != 2 else True
if not self.training and self.track_running_stats and not channel_last:
if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z is None:
# fall back to pytorch implementation for inference
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else:
@@ -81,4 +82,4 @@ class SyncBatchNorm(_BatchNorm):
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:
exponential_average_factor = self.momentum
return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last)
return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
@@ -7,7 +7,7 @@ from apex.parallel import ReduceOp
class SyncBatchnormFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False):
def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):
torch.cuda.nvtx.range_push("sync_BN_fw")
input = input.contiguous()
world_size = 0
@@ -53,13 +53,14 @@ class SyncBatchnormFunction(Function):
mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, inv_std)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
ctx.fuse_relu = fuse_relu
if channel_last:
out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias)
out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu)
else:
out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
@@ -73,11 +74,17 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, mean, inv_std = ctx.saved_tensors
saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
grad_input = grad_weight = grad_bias = None
fuse_relu = ctx.fuse_relu
grad_input = grad_z = grad_weight = grad_bias = None
if fuse_relu:
grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias)
if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
grad_z = grad_output.clone()
# TODO(jie): why do I have to clone here? lifetime of grad_output?
if channel_last:
@@ -100,11 +107,11 @@ class SyncBatchnormFunction(Function):
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
if weight is None or not ctx.needs_input_grad[1]:
if weight is None or not ctx.needs_input_grad[2]:
grad_weight = None
if weight is None or not ctx.needs_input_grad[2]:
if weight is None or not ctx.needs_input_grad[3]:
grad_bias = None
torch.cuda.nvtx.range_pop()
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
@@ -55,10 +55,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
const at::optional<at::Tensor> shift,
const bool fuse_relu);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
@@ -82,6 +84,15 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
@@ -92,4 +103,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
}
@@ -590,6 +590,58 @@ template <
int PARALLEL_LOADS>
__global__ void batchnorm_forward_c_last_kernel(
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ z,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const layerscalar_t* __restrict__ shift,
scalar_t* __restrict__ out,
const int reduction_size,
const int stride,
const bool fuse_relu) {
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
// offset along m dimension
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
auto m_c = mean[c_offset];
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
int address_increment = inner_loop_stride * stride;
for (int i = 0; i < loop_count; i++) {
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
if (z != NULL) {
tmp += z[address_base];
}
out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
}
m_offset += inner_loop_stride;
address_base += address_increment;
}
}
}
// elementwise BN kernel
template <
typename scalar_t,
typename accscalar_t,
typename layerscalar_t,
int PARALLEL_LOADS>
__global__ void relu_backward_c_last_kernel(
const scalar_t* __restrict__ grad_output,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ z,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
@@ -618,9 +670,11 @@ __global__ void batchnorm_forward_c_last_kernel(
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
out[address_base] = static_cast<scalar_t>(
w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c
);
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
if (z != NULL) {
tmp += z[address_base];
}
out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
}
m_offset += inner_loop_stride;
address_base += address_increment;
@@ -1146,10 +1200,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
at::Tensor batchnorm_forward_c_last_CUDA(
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const at::optional<at::Tensor> shift,
const bool fuse_relu) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
@@ -1169,13 +1225,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
stride,
fuse_relu);
);
} else {
if (weight.has_value()) {
@@ -1188,13 +1246,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
stride,
fuse_relu);
);
}
return out;
@@ -1350,3 +1410,66 @@ at::Tensor batchnorm_backward_c_last_CUDA(
return grad_input;
}
at::Tensor relu_backward_c_last_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor out = at::empty_like(input);
dim3 block;
dim3 grid;
flexible_launch_configs(reduction_size, stride, block, grid);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
);
} else {
if (weight.has_value()) {
AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
);
}
return out;
}
@@ -99,6 +99,35 @@ if "--cuda_ext" in sys.argv:
'-O3',
'--use_fast_math'] + version_ge_1_1}))
if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--bnp")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
ext_modules.append(
CUDAExtension(name='bnp',
sources=['apex/contrib/csrc/groupbn/batch_norm.cu',
'apex/contrib/csrc/groupbn/ipc.cu',
'apex/contrib/csrc/groupbn/interface.cpp',
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
extra_compile_args={'cxx': [] + version_ge_1_1,
'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
'-gencode',
'arch=compute_70,code=sm_70'] + version_ge_1_1}))
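# Usage note (descriptive comment only): with the block above, the bnp extension is built when
# "--bnp" appears in sys.argv, e.g. (assuming the usual apex install pattern; exact invocation may vary)
#   pip install -v --no-cache-dir --global-option="--bnp" .
# or
#   python setup.py install --bnp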
setup(
name='apex',
version='0.1',