Unverified Commit 31aceeaa authored by Deyu Fu's avatar Deyu Fu Committed by GitHub
Browse files

Improvements to apex.mlp (#804)

* update fused bias relu backward kernel

* adding support for not require first layer dgrad

* fix bug: wrong layer in requires grad

* add infrastructure for optional bias and activation, currently only support no bias and no relu

* make bias and relu optional separately

* add sigmoid activation option
parent aad9300b
...@@ -7,17 +7,19 @@ from .. import amp ...@@ -7,17 +7,19 @@ from .. import amp
class MlpFunction(torch.autograd.Function): class MlpFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, *args): def forward(ctx, bias, activation, *args):
output = mlp_cuda.forward(args) output = mlp_cuda.forward(bias, activation, args)
ctx.save_for_backward(*args) ctx.save_for_backward(*args)
ctx.outputs = output ctx.outputs = output
ctx.bias = bias
ctx.activation = activation
return output[0] return output[0]
@staticmethod @staticmethod
def backward(ctx, grad_o): def backward(ctx, grad_o):
grads = mlp_cuda.backward(grad_o, ctx.outputs, ctx.saved_tensors) grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors)
del ctx.outputs del ctx.outputs
return tuple(grads) return (None, None, *grads)
mlp_function = amp.half_function(MlpFunction.apply) mlp_function = amp.half_function(MlpFunction.apply)
...@@ -29,16 +31,21 @@ class MLP(torch.nn.Module): ...@@ -29,16 +31,21 @@ class MLP(torch.nn.Module):
bias (bool): Default True: bias (bool): Default True:
relu (bool): Default True relu (bool): Default True
""" """
def __init__(self, mlp_sizes, bias=True, relu=True): def __init__(self, mlp_sizes, bias=True, activation='relu'):
if not (bias and relu):
raise TypeError("bias and relu must be both true.")
super(MLP, self).__init__() super(MLP, self).__init__()
self.num_layers = len(mlp_sizes) - 1 self.num_layers = len(mlp_sizes) - 1
self.mlp_sizes = copy(mlp_sizes) self.mlp_sizes = copy(mlp_sizes)
self.bias = bias self.bias = 1 if bias else 0
self.relu= relu
if activation is 'none':
self.activation = 0
elif activation is 'relu':
self.activation = 1
elif activation is 'sigmoid':
self.activation = 2
else:
raise TypeError("activation must be relu or none.")
# ignoring bias = False now
self.weights = [] self.weights = []
self.biases = [] self.biases = []
for i in range(self.num_layers): for i in range(self.num_layers):
...@@ -46,10 +53,11 @@ class MLP(torch.nn.Module): ...@@ -46,10 +53,11 @@ class MLP(torch.nn.Module):
self.weights.append(w) self.weights.append(w)
name = 'weight_{}'.format(i) name = 'weight_{}'.format(i)
setattr(self, name, w) setattr(self, name, w)
b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1])) if self.bias:
self.biases.append(b) b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1]))
name = 'bias_{}'.format(i) self.biases.append(b)
setattr(self, name, b) name = 'bias_{}'.format(i)
setattr(self, name, b)
self.reset_parameters() self.reset_parameters()
...@@ -58,13 +66,14 @@ class MLP(torch.nn.Module): ...@@ -58,13 +66,14 @@ class MLP(torch.nn.Module):
dimsum = weight.size(0) + weight.size(1) dimsum = weight.size(0) + weight.size(1)
std = math.sqrt(2. / float(dimsum)) std = math.sqrt(2. / float(dimsum))
nn.init.normal_(weight, 0., std) nn.init.normal_(weight, 0., std)
for bias in self.biases: if self.bias:
std = math.sqrt(1. / float(bias.size(0))) for bias in self.biases:
nn.init.normal_(bias, 0., std) std = math.sqrt(1. / float(bias.size(0)))
nn.init.normal_(bias, 0., std)
def forward(self, input): def forward(self, input):
return mlp_function(input, *self.weights, *self.biases) return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases)
def extra_repr(self): def extra_repr(self):
s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, ReLU={self.relu}" s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}"
return s return s
...@@ -19,7 +19,9 @@ int mlp_fp( ...@@ -19,7 +19,9 @@ int mlp_fp(
int* output_features, int* output_features,
T** BPtr, T** BPtr,
T* Y, T* Y,
T* reserved_space); T* reserved_space,
int use_bias,
int activation);
template <typename T> template <typename T>
int mlp_bp( int mlp_bp(
...@@ -35,11 +37,18 @@ int mlp_bp( ...@@ -35,11 +37,18 @@ int mlp_bp(
T* work_space, T* work_space,
T* dX, T* dX,
T** dwPtr, T** dwPtr,
T** dbPtr); T** dbPtr,
bool requires_grad,
int use_bias,
int activation);
std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {
std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) { auto num_layers = inputs.size() - 1;
// inputs contains (input, weights, biases) if (use_bias) {
auto num_layers = (inputs.size() - 1) / 2; // inputs contains (input, weights, biases)
num_layers /= 2;
}
auto batch_size = inputs[0].size(0); auto batch_size = inputs[0].size(0);
auto input_features = inputs[0].size(1); auto input_features = inputs[0].size(1);
...@@ -60,7 +69,9 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) { ...@@ -60,7 +69,9 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
std::vector<scalar_t*> b_ptr; std::vector<scalar_t*> b_ptr;
for (int i = 0; i < num_layers; i++) { for (int i = 0; i < num_layers; i++) {
w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>()); w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>()); if (use_bias) {
b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
}
} }
auto result = mlp_fp<scalar_t>( auto result = mlp_fp<scalar_t>(
inputs[0].data_ptr<scalar_t>(), inputs[0].data_ptr<scalar_t>(),
...@@ -71,37 +82,48 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) { ...@@ -71,37 +82,48 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
output_features.data(), output_features.data(),
b_ptr.data(), b_ptr.data(),
out.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(),
reserved_space.data_ptr<scalar_t>()); reserved_space.data_ptr<scalar_t>(),
use_bias,
activation);
}); });
return {out, reserved_space}; return {out, reserved_space};
} }
std::vector<at::Tensor> mlp_backward( std::vector<at::Tensor> mlp_backward(
at::Tensor grad_o, int use_bias,
std::vector<at::Tensor> fprop_outputs, int activation,
std::vector<at::Tensor> inputs) { at::Tensor grad_o,
// same code to get sizes and W pointers std::vector<at::Tensor> fprop_outputs,
auto num_layers = (inputs.size() - 1) / 2; std::vector<at::Tensor> inputs) {
auto num_layers = inputs.size() - 1;
if (use_bias) {
// inputs contains (input, weights, biases)
num_layers /= 2;
}
auto batch_size = inputs[0].size(0); auto batch_size = inputs[0].size(0);
auto input_features = inputs[0].size(1); auto input_features = inputs[0].size(1);
// TODO: not creating empty tensor for it?
bool requires_grad = inputs[0].requires_grad();
std::vector<int> output_features; std::vector<int> output_features;
for (int i = 0; i < num_layers; i++) { for (int i = 0; i < num_layers; i++) {
output_features.push_back(inputs[i + 1].size(0)); output_features.push_back(inputs[i + 1].size(0));
} }
// create outputs, length of inputs // create outputs, length of inputs
// TODO: not create bias if not needed
std::vector<at::Tensor> outputs; std::vector<at::Tensor> outputs;
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now
} }
AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] { AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
std::vector<scalar_t*> w_ptr; std::vector<scalar_t*> w_ptr;
std::vector<scalar_t*> b_ptr;
for (int i = 0; i < num_layers; i++) { for (int i = 0; i < num_layers; i++) {
w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>()); w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
} }
std::vector<scalar_t*> outputs_ptr; std::vector<scalar_t*> outputs_ptr;
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
...@@ -127,7 +149,10 @@ std::vector<at::Tensor> mlp_backward( ...@@ -127,7 +149,10 @@ std::vector<at::Tensor> mlp_backward(
work_space.data_ptr<scalar_t>(), work_space.data_ptr<scalar_t>(),
outputs_ptr[0], outputs_ptr[0],
outputs_ptr.data() + 1, outputs_ptr.data() + 1,
outputs_ptr.data() + 1 + num_layers); outputs_ptr.data() + 1 + num_layers,
requires_grad,
use_bias,
activation);
}); });
return outputs; return outputs;
......
...@@ -10,8 +10,11 @@ ...@@ -10,8 +10,11 @@
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#define BIASADDRELU_FPROP_NUM_THREADS 128 // constants for fused bias+relu kernel
#define BIASADDRELU_BPROP_NUM_THREADS 128 #define BIAS_RELU_FW_NTHREADS 128 // forward number of thread per block
#define BIAS_RELU_BW_NTHREADS_X 32 // backward number of thread in feature dim
#define BIAS_RELU_BW_NTHREADS_Y 16 // backward number of thread in batch dim
#define BIAS_RELU_RED_PER_THREAD 16 // backward minimal reduction length per thread
// move to a header later on // move to a header later on
#define ILP 4 #define ILP 4
...@@ -42,6 +45,12 @@ __device__ __inline__ float relu(float a) { ...@@ -42,6 +45,12 @@ __device__ __inline__ float relu(float a) {
return (retf); return (retf);
} }
// Keep Sigmoid in float only. When using half, cast to float before calling.
__device__ __inline__ float sigmoid(float a) {
float retf = 1.f / (1.f + expf(-a));
return (retf);
}
// FP64 Wrapper around cublas GEMMEx // FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm( cublasStatus_t mlp_gemm(
cublasHandle_t handle, cublasHandle_t handle,
...@@ -156,9 +165,55 @@ cublasStatus_t mlp_gemm( ...@@ -156,9 +165,55 @@ cublasStatus_t mlp_gemm(
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} }
// Bias ADD + ReLU. Assume input X is [features x batch size], assume column major. // Bias ADD. Assume input X is [features x batch size], column major.
// Bias is one 'features' long vector, with implicit broadcast. // Bias is one 'features' long vector, with implicit broadcast.
// Currently, activation support fuesed ReLU. Safe to call in-place. template <typename T>
__global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) {
T r_x[ILP];
T r_b[ILP];
if(is_aligned(X) && is_aligned(b) && features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
int row = tid % (features / ILP);
load_store(r_x, X, 0 , tid);
load_store(r_b, b, 0 , row);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
r_x[ii] = bias_sum;
}
load_store(X, r_x, tid , 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
int row = tid % features;
r_x[ii] = X[idx];
r_b[ii] = b[row];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
float bias_sum = static_cast<float>(r_x[ii]) + static_cast<float>(r_b[ii]);
r_x[ii] = bias_sum;
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
X[idx] = r_x[ii];
}
}
}
}
}
// Bias ADD + ReLU. Assume input X is [features x batch size], column major.
// Activation support fuesed ReLU. Safe to call in-place.
template <typename T> template <typename T>
__global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) { __global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
T r_x[ILP]; T r_x[ILP];
...@@ -204,32 +259,308 @@ __global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) { ...@@ -204,32 +259,308 @@ __global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) {
} }
} }
// ReLU. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Relu_fprop(T *X, uint batch_size, uint features) {
T r_x[ILP];
if(is_aligned(X) && features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
load_store(r_x, X, 0 , tid);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
r_x[ii] = relu(static_cast<float>(r_x[ii]));
}
load_store(X, r_x, tid , 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
r_x[ii] = X[idx];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
r_x[ii] = relu(static_cast<float>(r_x[ii]));
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
X[idx] = r_x[ii];
}
}
}
}
}
// Sigmoid. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) {
T r_x[ILP];
if(is_aligned(X) && features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
load_store(r_x, X, 0 , tid);
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
}
load_store(X, r_x, tid , 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
r_x[ii] = X[idx];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
r_x[ii] = sigmoid(static_cast<float>(r_x[ii]));
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
X[idx] = r_x[ii];
}
}
}
}
}
// ReLU. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
T r_dy[ILP];
T r_y[ILP];
if(is_aligned(dY) &&
is_aligned(Y) &&
is_aligned(dX) &&
features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
load_store(r_dy, dY, 0 , tid);
load_store(r_y, Y, 0 , tid);
#pragma unroll
for(int ii=0;ii<ILP;ii++){
if ((float)r_y[ii] <= 0.f)
r_dy[ii] = 0;
}
load_store(dX, r_dy, tid, 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
r_dy[ii] = dY[idx];
r_y[ii] = Y[idx];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
if ((float)r_y[ii] <= 0.f)
r_dy[ii] = 0;
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
dX[idx] = r_dy[ii];
}
}
}
}
}
// Sigmoid. Assume input X is [features x batch size], column major.
// Safe to call in-place.
template <typename T>
__global__ void Sigmoid_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) {
T r_dy[ILP];
T r_y[ILP];
if(is_aligned(dY) &&
is_aligned(Y) &&
is_aligned(dX) &&
features % ILP ==0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) {
load_store(r_dy, dY, 0 , tid);
load_store(r_y, Y, 0 , tid);
#pragma unroll
for(int ii=0;ii<ILP;ii++){
float grad_out = r_dy[ii];
float out = r_y[ii];
float grad_i = out * ( 1.f - out) * grad_out;
r_dy[ii] = grad_i;
}
load_store(dX, r_dy, tid, 0);
}
} else {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
r_dy[ii] = dY[idx];
r_y[ii] = Y[idx];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
float grad_out = r_dy[ii];
float out = r_y[ii];
float grad_i = out * ( 1.f - out) * grad_out;
r_dy[ii] = grad_i;
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int idx = tid + ii * blockDim.x * gridDim.x;
if(idx < features * batch_size) {
dX[idx] = r_dy[ii];
}
}
}
}
}
// Compute grid size for pointwise backward kernel. // Compute grid size for pointwise backward kernel.
// Some intelligence needed to determine number of splits for reduction. // block_x/y is total elment being handled per block, not number of threads
void get_biasAddRelu_bprop_grid_size( void get_biasAddRelu_bprop_grid_size(
int yfeat, int yfeat,
int threadsPerBlock,
int batch_size, int batch_size,
int block_x,
int block_y,
int* grid_x, int* grid_x,
int* grid_y) { int* grid_y) {
*grid_x = (yfeat + block_x - 1) / block_x;
// Get number of SMs for efficient reduction. // Get number of SMs for efficient reduction.
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// First preference, whole reduction in 1 CTA // can switch to occupancy calculation. use 4 below now for sm_70
int nBlocks = (yfeat + threadsPerBlock - 1) / threadsPerBlock; int max_blocks_y = num_SMs * 4 / (*grid_x);
// block_y should be from minimal work per thread
// Figure out how many splits to divide reduction into. At least 32 elements per CTA. int nRedSplits = (batch_size + block_y - 1) / block_y;
// we want grid_y as close to sqrt(batchsize)? // increase number of elem per thread redcution to not launch more than enough
int nRedSplits = std::sqrt(batch_size); // kernel adjust work, so here we just launch max block
// for batchsize <=64, just use 1 block *grid_y = std::min(nRedSplits, max_blocks_y);
if(batch_size < 64) nRedSplits = 1;
// no need to go over occupancy
nRedSplits = min((8*num_SMs)/nBlocks, nRedSplits);
*grid_x = nBlocks;
*grid_y = nRedSplits;
return; return;
} }
// Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR>
__global__ void biasAdd_bprop(
T* dY,
int features,
int batch_size,
volatile float* intermediate,
int* semaphores,
T* db) {
// The feature that this thread is responsible for
int f = blockIdx.x * blockDim.x + threadIdx.x;
// Compute the span this thread is responsible for
// For this block
int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
int b_nStart = blockIdx.y * b_chunkSize;
int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
// For this thread
int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
int nStart = threadIdx.y * chunkSize + b_nStart;
int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
volatile float* out = intermediate + blockIdx.y * features;
// Flag to trigger last reduction.
__shared__ bool isLastBlock;
// we know block size for now
__shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];
// Accumulate db in FP32 always
float db_local = 0;
if (f < features) {
int nidx = 0;
// Handle non-multiple of UNROLL_FACTOR residue
for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
int row, col, flat_idx;
row = f;
col = nStart + nidx;
flat_idx = col * features + row;
db_local += (float)dY[flat_idx];
}
// Handle meat of work
for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
int row, col, flat_idx;
row = f;
col = nStart + nidx;
flat_idx = col * features + row;
#pragma unroll 4
for (int u = 0; u < UNROLL_FACTOR; u++) {
db_local += (float)dY[flat_idx];
flat_idx += features;
}
}
// naive block reduction on y-dim
int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
smem[linear_idx] = db_local;
}
__syncthreads();
if (f < features) {
if(threadIdx.y == 0) {
for(int yidx = 1; yidx < blockDim.y; yidx++){
db_local += smem[yidx * blockDim.x + threadIdx.x];
}
// block result is in db_local now for all threadIdx.y == 0
// Write out partial result
out[f] = db_local;
}
}
__threadfence();
__syncthreads();
// Increment semaphore and check if this is the last CTA in the grid_y dimension.
// Only thread (0,0) calls this
if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
unsigned int sum_idx;
sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
isLastBlock = (sum_idx == (gridDim.y - 1));
}
__syncthreads();
db_local = 0;
// No block reduction for now, only thread (*,0) do grid reduction
if (isLastBlock && f < features) {
if(threadIdx.y == 0) {
for (int n = 0; n < gridDim.y; n++) {
int row, col;
row = f;
col = n;
db_local += (float)(intermediate[col * features + row]);
}
db[f] = (T)db_local;
}
}
}
// Addition done deterministically via a 2-pass approach. Each CTA writes out partial // Addition done deterministically via a 2-pass approach. Each CTA writes out partial
// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result. // sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result.
template <typename T, int UNROLL_FACTOR> template <typename T, int UNROLL_FACTOR>
...@@ -245,14 +576,22 @@ __global__ void biasAddRelu_bprop( ...@@ -245,14 +576,22 @@ __global__ void biasAddRelu_bprop(
// The feature that this thread is responsible for // The feature that this thread is responsible for
int f = blockIdx.x * blockDim.x + threadIdx.x; int f = blockIdx.x * blockDim.x + threadIdx.x;
// Compute the batch span this thread is responsible for // Compute the span this thread is responsible for
int chunkSize = (batch_size + gridDim.y - 1) / gridDim.y; // For this block
int nStart = blockIdx.y * chunkSize; int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
int nSpan = min(batch_size, nStart + chunkSize) - nStart; int b_nStart = blockIdx.y * b_chunkSize;
int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
// For this thread
int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
int nStart = threadIdx.y * chunkSize + b_nStart;
int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
volatile float* out = intermediate + blockIdx.y * features; volatile float* out = intermediate + blockIdx.y * features;
// Flag to trigger last reduction. // Flag to trigger last reduction.
__shared__ bool isLastBlock; __shared__ bool isLastBlock;
// we know block size for now
__shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y];
// Accumulate db in FP32 always // Accumulate db in FP32 always
float db_local = 0; float db_local = 0;
...@@ -296,15 +635,28 @@ __global__ void biasAddRelu_bprop( ...@@ -296,15 +635,28 @@ __global__ void biasAddRelu_bprop(
} }
} }
// Write out partial result // naive block reduction on y-dim
out[f] = db_local; int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
smem[linear_idx] = db_local;
}
__syncthreads();
if (f < features) {
if(threadIdx.y == 0) {
for(int yidx = 1; yidx < blockDim.y; yidx++){
db_local += smem[yidx * blockDim.x + threadIdx.x];
}
// block result is in db_local now for all threadIdx.y == 0
// Write out partial result
out[f] = db_local;
}
} }
__threadfence(); __threadfence();
__syncthreads(); __syncthreads();
// Increment semaphore and check if this is the last CTA in // Increment semaphore and check if this is the last CTA in the grid_y dimension.
// the grid_y dimension. // Only thread (0,0) calls this
if (threadIdx.x == 0 && f < features) { if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) {
unsigned int sum_idx; unsigned int sum_idx;
sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1); sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
isLastBlock = (sum_idx == (gridDim.y - 1)); isLastBlock = (sum_idx == (gridDim.y - 1));
...@@ -312,14 +664,17 @@ __global__ void biasAddRelu_bprop( ...@@ -312,14 +664,17 @@ __global__ void biasAddRelu_bprop(
__syncthreads(); __syncthreads();
db_local = 0; db_local = 0;
// No block reduction for now, only thread (*,0) do grid reduction
if (isLastBlock && f < features) { if (isLastBlock && f < features) {
for (int n = 0; n < gridDim.y; n++) { if(threadIdx.y == 0) {
int row, col; for (int n = 0; n < gridDim.y; n++) {
row = f; int row, col;
col = n; row = f;
db_local += (float)(intermediate[col * features + row]); col = n;
db_local += (float)(intermediate[col * features + row]);
}
db[f] = (T)db_local;
} }
db[f] = (T)db_local;
} }
} }
...@@ -338,10 +693,16 @@ __global__ void biasAddRelu_bprop_aligned( ...@@ -338,10 +693,16 @@ __global__ void biasAddRelu_bprop_aligned(
// The feature that this thread is responsible for // The feature that this thread is responsible for
int f = blockIdx.x * blockDim.x + threadIdx.x; int f = blockIdx.x * blockDim.x + threadIdx.x;
// Compute the batch span this thread is responsible for // Compute the span this thread is responsible for
int chunkSize = (batch_size + gridDim.y - 1) / gridDim.y; // For this block
int nStart = blockIdx.y * chunkSize; int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y;
int nSpan = min(batch_size, nStart + chunkSize) - nStart; int b_nStart = blockIdx.y * b_chunkSize;
int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart;
// For this thread
int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y;
int nStart = threadIdx.y * chunkSize + b_nStart;
int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart;
volatile float* out = intermediate + blockIdx.y * features; volatile float* out = intermediate + blockIdx.y * features;
// Flag to trigger last reduction. // Flag to trigger last reduction.
...@@ -399,24 +760,45 @@ __global__ void biasAddRelu_bprop_aligned( ...@@ -399,24 +760,45 @@ __global__ void biasAddRelu_bprop_aligned(
} }
} }
if(gridDim.y == 1) { // we know block size for now
__shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y*ILP];
// naive block reduction on y-dim
int linear_idx = threadIdx.y * blockDim.x + threadIdx.x;
float* smem_out = smem + ILP * linear_idx;
#pragma unroll #pragma unroll
for(int ii=0;ii<ILP;ii++){ for(int ii=0;ii<ILP;ii++){
r_dy[ii] = db_local[ii]; // reuse local dy buffer smem_out[ii] = db_local[ii]; // reuse local dy buffer
}
load_store(db, r_dy, f, 0);
return;
} }
__syncthreads();
if(threadIdx.y == 0) {
for(int yidx = 1; yidx < blockDim.y; yidx++){
float* smem_in = smem + ILP * (yidx * blockDim.x + threadIdx.x);
#pragma unroll
for(int ii=0;ii<ILP;ii++){
db_local[ii] += smem_in[ii]; // reuse local dy buffer
}
}
// Write out partial result // block result is in db_local now for all threadIdx.y == 0
load_store(out, db_local, f, 0); // TODO: maybe not useful early exit here
if(gridDim.y == 1) {
#pragma unroll
for(int ii=0;ii<ILP;ii++){
r_dy[ii] = db_local[ii]; // reuse local dy buffer
}
load_store(db, r_dy, f, 0);
return;
}
// Write out partial result
load_store(out, db_local, f, 0);
}
__threadfence(); __threadfence();
__syncthreads(); __syncthreads();
// Increment semaphore and check if this is the last CTA in // Increment semaphore and check if this is the last CTA in the grid_y dimension.
// the grid_y dimension. // Only thread (0,0) calls this
if (threadIdx.x == 0) { if (threadIdx.x == 0 && threadIdx.y == 0) {
unsigned int sum_idx; unsigned int sum_idx;
sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1); sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1);
isLastBlock = (sum_idx == (gridDim.y - 1)); isLastBlock = (sum_idx == (gridDim.y - 1));
...@@ -428,22 +810,26 @@ __global__ void biasAddRelu_bprop_aligned( ...@@ -428,22 +810,26 @@ __global__ void biasAddRelu_bprop_aligned(
db_local[ii] = 0.f; db_local[ii] = 0.f;
} }
float r_db[ILP]; float r_db[ILP];
// No block reduction for now, only thread (*,0) do grid reduction
if (isLastBlock) { if (isLastBlock) {
for (int n = 0; n < gridDim.y; n++) { if(threadIdx.y == 0){
int row, col; for (int n = 0; n < gridDim.y; n++) {
row = f; int row, col;
col = n; row = f;
load_store(r_db, intermediate, 0, col * features / ILP + row); col = n;
load_store(r_db, intermediate, 0, col * features / ILP + row);
#pragma unroll #pragma unroll
for(int ii=0;ii<ILP;ii++){ for(int ii=0;ii<ILP;ii++){
db_local[ii] += r_db[ii]; db_local[ii] += r_db[ii];
}
} }
}
#pragma unroll #pragma unroll
for(int ii=0;ii<ILP;ii++){ for(int ii=0;ii<ILP;ii++){
r_dy[ii] = db_local[ii]; // reuse local dy buffer r_dy[ii] = db_local[ii]; // reuse local dy buffer
}
load_store(db, r_dy, f, 0);
} }
load_store(db, r_dy, f, 0);
} }
} }
...@@ -502,10 +888,20 @@ size_t get_reduction_scratch_space(int batch_size, int num_layers, const int* ou ...@@ -502,10 +888,20 @@ size_t get_reduction_scratch_space(int batch_size, int num_layers, const int* ou
size_t max_scratch_space = 0; size_t max_scratch_space = 0;
// Loop over all layers to see which one needs the max scratch space // Loop over all layers to see which one needs the max scratch space
for (int l = 0; l < num_layers; l++) { for (int l = 0; l < num_layers; l++) {
int tmp, num_splits; // need to find max(aligned, not_aligned)
int tmp, res0, res1;
int block_x = BIAS_RELU_BW_NTHREADS_X;
int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
get_biasAddRelu_bprop_grid_size(
output_features[l], batch_size, block_x, block_y, &tmp, &res0);
block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
get_biasAddRelu_bprop_grid_size( get_biasAddRelu_bprop_grid_size(
output_features[l], BIASADDRELU_BPROP_NUM_THREADS, batch_size, &tmp, &num_splits); output_features[l], batch_size, block_x, block_y, &tmp, &res1);
max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * num_splits));
max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res0));
max_scratch_space = std::max(max_scratch_space, (size_t)(output_features[l] * res1));
} }
return max_scratch_space; return max_scratch_space;
...@@ -581,7 +977,9 @@ int mlp_fp( ...@@ -581,7 +977,9 @@ int mlp_fp(
int* output_features, int* output_features,
T** BPtr, T** BPtr,
T* Y, T* Y,
T* reserved_space) { T* reserved_space,
int use_bias,
int activation) {
T *weight, *input, *output, *bias; T *weight, *input, *output, *bias;
T *reserved_space_x, *reserved_space_y; T *reserved_space_x, *reserved_space_y;
reserved_space_x = NULL; reserved_space_x = NULL;
...@@ -597,7 +995,9 @@ int mlp_fp( ...@@ -597,7 +995,9 @@ int mlp_fp(
weight = WPtr[layer]; weight = WPtr[layer];
input = (layer == 0) ? X : reserved_space_x; input = (layer == 0) ? X : reserved_space_x;
output = (layer == num_layers - 1) ? Y : reserved_space_y; output = (layer == num_layers - 1) ? Y : reserved_space_y;
bias = BPtr[layer]; if (use_bias) {
bias = BPtr[layer];
}
int ifeat = (layer == 0) ? input_features : output_features[layer - 1]; int ifeat = (layer == 0) ? input_features : output_features[layer - 1];
int ofeat = output_features[layer]; int ofeat = output_features[layer];
...@@ -627,12 +1027,33 @@ int mlp_fp( ...@@ -627,12 +1027,33 @@ int mlp_fp(
return 1; return 1;
} }
// Call biasReLU
const uint &input_size = ofeat; const uint &input_size = ofeat;
int num_blocks = 0; int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIASADDRELU_FPROP_NUM_THREADS, 0); // Call biasReLU
biasAddRelu_fprop<<<num_SMs*num_blocks, BIASADDRELU_FPROP_NUM_THREADS, 0, stream>>>(output, bias, batch_size, input_size); if(use_bias == 1) {
if (activation == 0) { // no activation
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 1) { // relu
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAddRelu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 2) { // sigmoid
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
}
} else {
// don't need to do anything in case of no activation and no bias
if (activation == 1) { // relu
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Relu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
} else if (activation == 2) { // sigmoid
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
}
}
// Set current output as next layer input // Set current output as next layer input
reserved_space_x = reserved_space_y; reserved_space_x = reserved_space_y;
...@@ -660,7 +1081,10 @@ int mlp_bp( ...@@ -660,7 +1081,10 @@ int mlp_bp(
T* work_space, T* work_space,
T* dX, T* dX,
T** dwPtr, T** dwPtr,
T** dbPtr) { T** dbPtr,
bool requires_grad,
int use_bias,
int activation) {
T* weight; T* weight;
T *dweight, *dx, *dy, *dbias; T *dweight, *dx, *dy, *dbias;
T *x, *y; T *x, *y;
...@@ -719,32 +1143,85 @@ int mlp_bp( ...@@ -719,32 +1143,85 @@ int mlp_bp(
float one = 1.f; float one = 1.f;
float zero = 0.f; float zero = 0.f;
// Call bias ReLU backprop - first implementation, 1 thread per bias element if (use_bias == 1) {
int threadsPerBlock = BIASADDRELU_BPROP_NUM_THREADS; if (activation == 0) { // no acitvation
int grid_x, grid_y; // bgrad
get_biasAddRelu_bprop_grid_size(yfeat, threadsPerBlock, batch_size, &grid_x, &grid_y); dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
int grid_x, grid_y;
dim3 block(threadsPerBlock); cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
cudaMemsetAsync(semaphores, 0, semaphore_size, stream); int block_x = BIAS_RELU_BW_NTHREADS_X;
int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
if(yfeat % (ILP * threadsPerBlock) == 0 && get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
is_aligned(y) && dim3 grid(grid_x, grid_y);
is_aligned(dy) && biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
is_aligned(dy_gemm) && dy, yfeat, batch_size, db_scratch, semaphores, dbias);
is_aligned(dbias)){ // bypass dgrad through reset pointer
dim3 grid(grid_x/ILP, grid_y); dy_gemm = dy;
biasAddRelu_bprop_aligned<T, 4><<<grid, block, 0, stream>>>( } else if (activation == 1) { // relu
y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias); dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
} else { int grid_x, grid_y;
dim3 grid(grid_x, grid_y); cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
biasAddRelu_bprop<T, 4><<<grid, block, 0, stream>>>(
y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias); if(yfeat % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 &&
is_aligned(y) &&
is_aligned(dy) &&
is_aligned(dy_gemm) &&
is_aligned(dbias)){
int block_x = ILP * BIAS_RELU_BW_NTHREADS_X;
int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
dim3 grid(grid_x, grid_y);
biasAddRelu_bprop_aligned<T, 4><<<grid, block, 0, stream>>>(
y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
} else {
int block_x = BIAS_RELU_BW_NTHREADS_X;
int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
dim3 grid(grid_x, grid_y);
biasAddRelu_bprop<T, 4><<<grid, block, 0, stream>>>(
y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias);
}
} else if (activation == 2) { // sigmoid
// activation backward
int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
// bgrad, from dy_gemm
dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y);
int grid_x, grid_y;
cudaMemsetAsync(semaphores, 0, semaphore_size, stream);
int block_x = BIAS_RELU_BW_NTHREADS_X;
int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y;
get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y);
dim3 grid(grid_x, grid_y);
biasAdd_bprop<T, 4><<<grid, block, 0, stream>>>(
dy_gemm, yfeat, batch_size, db_scratch, semaphores, dbias);
}
} else { // no bias below
if (activation == 0) {
// bypass dgrad through reset pointer
dy_gemm = dy;
} else if (activation == 1) { // relu
int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Relu_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
} else if (activation == 2) { // sigmoid
int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_bprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(dy, y, batch_size, yfeat, dy_gemm);
}
} }
cublasStatus_t cublas_status; cublasStatus_t cublas_status;
// Call GEMM dgrad // Call GEMM dgrad
cublas_status = mlp_gemm( if (layer > 0 || requires_grad == 1) {
cublas_status = mlp_gemm(
handle, handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
...@@ -760,9 +1237,10 @@ int mlp_bp( ...@@ -760,9 +1237,10 @@ int mlp_bp(
dx, dx,
xfeat); xfeat);
if (cublas_status != CUBLAS_STATUS_SUCCESS) { if (cublas_status != CUBLAS_STATUS_SUCCESS) {
printf("GEMM dgrad failed with %d\n", cublas_status); printf("GEMM dgrad failed with %d\n", cublas_status);
return 1; return 1;
}
} }
// Call GEMM wgrad // Call GEMM wgrad
...@@ -801,7 +1279,9 @@ template int mlp_fp<float>( ...@@ -801,7 +1279,9 @@ template int mlp_fp<float>(
int* output_features, int* output_features,
float** BPtr, float** BPtr,
float* Y, float* Y,
float* reserved_space); float* reserved_space,
int use_bias,
int activation);
template int mlp_bp<float>( template int mlp_bp<float>(
float* X, float* X,
...@@ -816,7 +1296,10 @@ template int mlp_bp<float>( ...@@ -816,7 +1296,10 @@ template int mlp_bp<float>(
float* work_space, float* work_space,
float* dX, float* dX,
float** dwPtr, float** dwPtr,
float** dbPtr); float** dbPtr,
bool requires_grad,
int use_bias,
int activation);
template int mlp_fp<at::Half>( template int mlp_fp<at::Half>(
at::Half* X, at::Half* X,
...@@ -827,7 +1310,9 @@ template int mlp_fp<at::Half>( ...@@ -827,7 +1310,9 @@ template int mlp_fp<at::Half>(
int* output_features, int* output_features,
at::Half** BPtr, at::Half** BPtr,
at::Half* Y, at::Half* Y,
at::Half* reserved_space); at::Half* reserved_space,
int use_bias,
int activation);
template int mlp_bp<at::Half>( template int mlp_bp<at::Half>(
at::Half* X, at::Half* X,
...@@ -842,7 +1327,10 @@ template int mlp_bp<at::Half>( ...@@ -842,7 +1327,10 @@ template int mlp_bp<at::Half>(
at::Half* work_space, at::Half* work_space,
at::Half* dX, at::Half* dX,
at::Half** dwPtr, at::Half** dwPtr,
at::Half** dbPtr); at::Half** dbPtr,
bool requires_grad,
int use_bias,
int activation);
template int mlp_fp<double>( template int mlp_fp<double>(
double* X, double* X,
...@@ -853,7 +1341,9 @@ template int mlp_fp<double>( ...@@ -853,7 +1341,9 @@ template int mlp_fp<double>(
int* output_features, int* output_features,
double** BPtr, double** BPtr,
double* Y, double* Y,
double* reserved_space); double* reserved_space,
int use_bias,
int activation);
template int mlp_bp<double>( template int mlp_bp<double>(
double* X, double* X,
...@@ -868,7 +1358,10 @@ template int mlp_bp<double>( ...@@ -868,7 +1358,10 @@ template int mlp_bp<double>(
double* work_space, double* work_space,
double* dX, double* dX,
double** dwPtr, double** dwPtr,
double** dbPtr); double** dbPtr,
bool requires_grad,
int use_bias,
int activation);
template size_t get_mlp_bp_workspace_in_bytes<float>( template size_t get_mlp_bp_workspace_in_bytes<float>(
int batch_size, int batch_size,
......
...@@ -51,6 +51,116 @@ class TestMLP(unittest.TestCase): ...@@ -51,6 +51,116 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].bias.grad.detach().cpu().numpy(), ref_mlp[0].bias.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5) atol=1e-7, rtol=1e-5)
def test_no_bias(self):
for use_activation in ['none', 'relu', 'sigmoid']:
mlp = MLP(mlp_sizes, bias=False, activation=use_activation).cuda()
mlp_layers = []
for i in range(mlp.num_layers):
linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=False)
mlp.weights[i].data.copy_(linear.weight)
mlp_layers.append(linear)
if use_activation == 'relu':
mlp_layers.append(nn.ReLU(inplace=True))
if use_activation == 'sigmoid':
mlp_layers.append(nn.Sigmoid())
ref_mlp = nn.Sequential(*mlp_layers).cuda()
test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.).requires_grad_()
ref_input = test_input.clone().detach().requires_grad_()
mlp_out = mlp(test_input)
ref_out = ref_mlp(ref_input)
np.testing.assert_allclose(
mlp_out.detach().cpu().numpy(),
ref_out.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
# Use mean value as scalar loss. Multiply 10 to make it big enough not zero out
mlp_out.mean().mul(10.).backward()
ref_out.mean().mul(10.).backward()
np.testing.assert_allclose(
test_input.grad.detach().cpu().numpy(),
ref_input.grad.detach().cpu().numpy(),
atol=0, rtol=100)
np.testing.assert_allclose(
mlp.weights[0].grad.detach().cpu().numpy(),
ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=100)
def test_with_bias(self):
for use_activation in ['none', 'relu', 'sigmoid']:
mlp = MLP(mlp_sizes, bias=True, activation=use_activation).cuda()
mlp_layers = []
for i in range(mlp.num_layers):
linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=True)
mlp.weights[i].data.copy_(linear.weight)
mlp.biases[i].data.copy_(linear.bias)
mlp_layers.append(linear)
if use_activation == 'relu':
mlp_layers.append(nn.ReLU(inplace=True))
if use_activation == 'sigmoid':
mlp_layers.append(nn.Sigmoid())
ref_mlp = nn.Sequential(*mlp_layers).cuda()
test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.).requires_grad_()
ref_input = test_input.clone().detach().requires_grad_()
mlp_out = mlp(test_input)
ref_out = ref_mlp(ref_input)
np.testing.assert_allclose(
mlp_out.detach().cpu().numpy(),
ref_out.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
# Use mean value as scalar loss. Multiply 10 to make it big enough not zero out
mlp_out.mean().mul(10.).backward()
ref_out.mean().mul(10.).backward()
np.testing.assert_allclose(
test_input.grad.detach().cpu().numpy(),
ref_input.grad.detach().cpu().numpy(),
atol=0, rtol=1)
np.testing.assert_allclose(
mlp.weights[0].grad.detach().cpu().numpy(),
ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1)
np.testing.assert_allclose(
mlp.biases[0].grad.detach().cpu().numpy(),
ref_mlp[0].bias.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
def test_no_grad(self):
mlp = MLP(mlp_sizes).cuda()
mlp_layers = []
for i in range(mlp.num_layers):
linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])
mlp.weights[i].data.copy_(linear.weight)
mlp.biases[i].data.copy_(linear.bias)
mlp_layers.append(linear)
mlp_layers.append(nn.ReLU(inplace=True))
ref_mlp = nn.Sequential(*mlp_layers).cuda()
test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.)
ref_input = test_input.clone().detach()
mlp_out = mlp(test_input)
ref_out = ref_mlp(ref_input)
np.testing.assert_allclose(
mlp_out.detach().cpu().numpy(),
ref_out.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
# Use mean value as scalar loss. Multiply 10 to make it big enough not zero out
mlp_out.mean().mul(10.).backward()
ref_out.mean().mul(10.).backward()
np.testing.assert_allclose(
mlp.weights[0].grad.detach().cpu().numpy(),
ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5)
def test_performance_half(self): def test_performance_half(self):
mlp = MLP(mlp_sizes).cuda().half() mlp = MLP(mlp_sizes).cuda().half()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment