Commit fa0e478a authored by Hang Zhang

SyncBN backend and double type

parent 1633f310
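This commit wires up the building blocks for synchronized batch normalization: a sum_square Function that reduces each channel to its sum and sum of squares, batchnormtrain/batchnormeval Functions that normalize with an externally supplied mean and std, and Double-typed CUDA backends registered next to the existing Float ones. A rough sketch of how the pieces are meant to compose (the import path, the eps value, and the cross-GPU all-reduce step are assumptions, not part of this diff):

import torch
from torch.autograd import Variable
from encoding.functions import sum_square, batchnormtrain  # path assumed

x = Variable(torch.randn(4, 8, 16, 16).cuda())   # (B, C, H, W)
B, C, H, W = x.size()
N = B * H * W                                    # elements per channel

xsum, xsquare = sum_square()(x)                  # per-channel sums, shape (C,)
# A SyncBN layer would all-reduce xsum/xsquare across GPUs here (e.g. NCCL).
mean = xsum / N
std = (xsquare / N - mean ** 2 + 1e-5).sqrt()    # eps assumed

gamma = Variable(torch.ones(C).cuda())
beta = Variable(torch.zeros(C).cuda())
# batchnormtrain operates on the flattened (B, C, N) view
y = batchnormtrain()(x.view(B, C, -1), gamma, beta, mean, std).view(B, C, H, W)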
@@ -8,9 +8,12 @@
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import threading
import torch
import torch.cuda.nccl as nccl
import torch.nn as nn
from torch.autograd import Function, Variable
from torch.nn.parameter import Parameter
from ._ext import encoding_lib

class aggregate(Function):
@@ -20,15 +23,26 @@ class aggregate(Function):
        B, N, K, D = R.size()
        E = A.new(B,K,D)
        # TODO support cpu backend
        if isinstance(A, torch.cuda.FloatTensor):
            encoding_lib.Encoding_Float_aggregate_forward(E, A, R)
        elif isinstance(A, torch.cuda.DoubleTensor):
            encoding_lib.Encoding_Double_aggregate_forward(E, A, R)
        else:
            raise RuntimeError('unimplemented')
        return E

    def backward(self, gradE):
        A, R = self.saved_tensors
        gradA = A.new().resize_as_(A)
        gradR = R.new().resize_as_(R)
        if isinstance(A, torch.cuda.FloatTensor):
            encoding_lib.Encoding_Float_aggregate_backward(gradA, gradR, gradE,
                                                           A, R)
        elif isinstance(A, torch.cuda.DoubleTensor):
            encoding_lib.Encoding_Double_aggregate_backward(gradA, gradR, gradE,
                                                            A, R)
        else:
            raise RuntimeError('unimplemented')
        return gradA, gradR
@@ -82,3 +96,131 @@ class Encoding(nn.Module):
    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' + str(self.D) + ')'
class sum_square(Function):
    def forward(ctx, input):
        ctx.save_for_backward(input)
        B,C,H,W = input.size()
        with torch.cuda.device_of(input):
            xsum = input.new().resize_(C).zero_()
            xsquare = input.new().resize_(C).zero_()
        if isinstance(input, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Float_sum_square_Forward(
                    input.view(B,C,-1), xsum, xsquare)
        elif isinstance(input, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Double_sum_square_Forward(
                    input.view(B,C,-1), xsum, xsquare)
        else:
            raise RuntimeError('unimplemented')
        return xsum, xsquare

    def backward(ctx, gradSum, gradSquare):
        input, = ctx.saved_tensors
        B,C,H,W = input.size()
        with torch.cuda.device_of(input):
            gradInput = input.new().resize_(B,C,H*W).zero_()
        # gradSum.view(1,C,1,1).expand_as(input) + \
        #     2*gradSquare.view(1,C,1,1).expand_as(input)*input
        if isinstance(input, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Float_sum_square_Backward(
                    gradInput, input.view(B,C,-1), gradSum, gradSquare)
        elif isinstance(input, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Double_sum_square_Backward(
                    gradInput, input.view(B,C,-1), gradSum, gradSquare)
        else:
            raise RuntimeError('unimplemented')
        return gradInput.view(B,C,H,W)
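# With N = B*H*W elements per channel, the two outputs above are the
# sufficient statistics for the batch moments:
#     mean_c = xsum_c / N
#     var_c  = xsquare_c / N - mean_c**2
# so synchronizing BN across GPUs only needs an all-reduce of two
# length-C vectors instead of the activations themselves.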
class batchnormtrain(Function):
    def forward(ctx, input, gamma, beta, mean, std):
        ctx.save_for_backward(input, gamma, beta, mean, std)
        assert(input.dim()==3)
        with torch.cuda.device_of(input):
            invstd = 1.0 / std
            output = input.new().resize_as_(input)
        if isinstance(input, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Float_batchnorm_Forward(output,
                    input, mean, invstd, gamma, beta)
        elif isinstance(input, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Double_batchnorm_Forward(output,
                    input, mean, invstd, gamma, beta)
        else:
            raise RuntimeError('unimplemented')
        return output

    def backward(ctx, gradOutput):
        input, gamma, beta, mean, std = ctx.saved_tensors
        invstd = 1.0 / std
        with torch.cuda.device_of(input):
            gradInput = gradOutput.new().resize_as_(input).zero_()
            gradGamma = gradOutput.new().resize_as_(gamma).zero_()
            gradBeta = gradOutput.new().resize_as_(beta).zero_()
            gradMean = gradOutput.new().resize_as_(mean).zero_()
            gradStd = gradOutput.new().resize_as_(std).zero_()
        if isinstance(input, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Float_batchnorm_Backward(
                    gradOutput, input, gradInput, gradGamma, gradBeta,
                    mean, invstd, gamma, beta, gradMean, gradStd,
                    True)
        elif isinstance(input, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Double_batchnorm_Backward(
                    gradOutput, input, gradInput, gradGamma, gradBeta,
                    mean, invstd, gamma, beta, gradMean, gradStd,
                    True)
        else:
            raise RuntimeError('unimplemented')
        return gradInput, gradGamma, gradBeta, gradMean, gradStd
class batchnormeval(Function):
    def forward(ctx, input, gamma, beta, mean, std):
        ctx.save_for_backward(input, gamma, beta, mean, std)
        assert(input.dim()==3)
        with torch.cuda.device_of(input):
            invstd = 1.0 / std
            output = input.new().resize_as_(input)
        if isinstance(input, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Float_batchnorm_Forward(output,
                    input, mean, invstd, gamma, beta)
        elif isinstance(input, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Double_batchnorm_Forward(output,
                    input, mean, invstd, gamma, beta)
        else:
            raise RuntimeError('unimplemented')
        return output

    def backward(ctx, gradOutput):
        input, gamma, beta, mean, std = ctx.saved_tensors
        invstd = 1.0 / std
        with torch.cuda.device_of(input):
            gradInput = gradOutput.new().resize_as_(input).zero_()
            gradGamma = gradOutput.new().resize_as_(gamma).zero_()
            gradBeta = gradOutput.new().resize_as_(beta).zero_()
            gradMean = gradOutput.new().resize_as_(mean).zero_()
            gradStd = gradOutput.new().resize_as_(std).zero_()
        if isinstance(input, torch.cuda.FloatTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Float_batchnorm_Backward(
                    gradOutput, input, gradInput, gradGamma, gradBeta,
                    mean, invstd, gamma, beta, gradMean, gradStd,
                    False)
        elif isinstance(input, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(input):
                encoding_lib.Encoding_Double_batchnorm_Backward(
                    gradOutput, input, gradInput, gradGamma, gradBeta,
                    mean, invstd, gamma, beta, gradMean, gradStd,
                    False)
        else:
            raise RuntimeError('unimplemented')
        return gradInput, gradGamma, gradBeta, gradMean, gradStd
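One payoff of the new Double kernels is that these Functions can be verified with torch.autograd's gradcheck, which only gives meaningful answers in double precision. A minimal sketch, assuming the shapes below and that gradcheck is importable this way in the PyTorch version at hand:

import torch
from torch.autograd import Variable, gradcheck

# gradcheck requires DoubleTensors, which is the main reason the
# Encoding_Double_* backends exist alongside the Float ones.
x = Variable(torch.randn(2, 3, 10).double().cuda(), requires_grad=True)
g = Variable(torch.randn(3).double().cuda(), requires_grad=True)
b = Variable(torch.randn(3).double().cuda(), requires_grad=True)
mean = Variable(torch.randn(3).double().cuda(), requires_grad=True)
std = Variable(torch.rand(3).double().cuda() + 1.0, requires_grad=True)
assert gradcheck(batchnormtrain(), (x, g, b, mean, std), eps=1e-6, atol=1e-4)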
from torch.utils.ffi import _wrap_function
from ._encoding_lib import lib as _lib, ffi as _ffi

__all__ = []

def _import_symbols(locals):
    for symbol in dir(_lib):
        fn = getattr(_lib, symbol)
        locals[symbol] = _wrap_function(fn, _ffi)
        __all__.append(symbol)

_import_symbols(locals())
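# This populates the module namespace with one wrapped callable per C
# symbol, e.g. Encoding_Float_sum_square_Forward and its Double twin,
# which is what `from ._ext import encoding_lib` picks up above.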
// The maximum number of threads in a block
const int WARP_SIZE = 32;
const int MAX_BLOCK_SIZE = 512;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
    int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
    for (int i = 0; i != 5; ++i) {
        if (nElem <= threadSizes[i]) {
            return threadSizes[i];
        }
    }
    return MAX_BLOCK_SIZE;
}

// Returns the index of the most significant 1 bit in `val`
__device__ __forceinline__ int getMSB(int val) {
    return 31 - __clz(val);
}
@@ -12,13 +12,13 @@
#define THC_GENERIC_FILE "generic/device_tensor.h"
#else
template <int Dim>
THCDeviceTensor<real, Dim> devicetensor(THCState *state, THCTensor *t) {
    if (!t) {
        return THCDeviceTensor<real, Dim>();
    }
    int inDim = THCTensor_(nDimension)(state, t);
    if (inDim == Dim) {
        return toDeviceTensor<real, Dim>(state, t);
    }
    // View in which the last dimensions are collapsed or expanded as needed
    THAssert(THCTensor_(isContiguous)(state, t));
@@ -32,6 +32,6 @@ THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCTensor *t) {
            size[Dim - 1] *= t->size[i];
        }
    }
    return THCDeviceTensor<real, Dim>(THCTensor_(data)(state, t), size);
}
#endif
@@ -125,4 +125,399 @@ void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *GA_,
        GR, L, A, R);
    THCudaCheck(cudaGetLastError());
}
/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
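/*
 * One block per channel c; each thread strides over the flattened
 * spatial positions and applies
 *     y[b][c][x] = gamma[c] * (x[b][c][x] - mean[c]) * invstd[c] + beta[c]
 * where invstd = 1/std is precomputed by the caller.
 */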
__global__ void Encoding_(BatchNorm_Forward_kernel) (
    THCDeviceTensor<real, 3> output,
    THCDeviceTensor<real, 3> input,
    THCDeviceTensor<real, 1> mean,
    THCDeviceTensor<real, 1> invstd,
    THCDeviceTensor<real, 1> gamma,
    THCDeviceTensor<real, 1> beta)
{
    int c = blockIdx.x;
    /* main operation */
    for (int b = 0; b < input.getSize(0); ++b) {
        for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) {
            real inp = input[b][c][x].ldg();
            output[b][c][x] = gamma[c].ldg() * (inp - mean[c].ldg()) *
                invstd[c].ldg() + beta[c].ldg();
        }
    }
}
void Encoding_(BatchNorm_Forward)(THCState *state,
    THCTensor *output_, THCTensor *input_,
    THCTensor *mean_, THCTensor *invstd_,
    THCTensor *gamma_, THCTensor *beta_)
/*
 * batch norm forward function
 * assuming the input is already flattened to (B, C, N)
 */
{
    /* Check the GPU index and tensor dims */
    THCTensor_(checkGPU)(state, 6, output_, input_, mean_, invstd_,
                         gamma_, beta_);
    if (THCTensor_(nDimension)(state, output_) != 3 ||
        THCTensor_(nDimension)(state, input_) != 3 ||
        THCTensor_(nDimension)(state, mean_) != 1 ||
        THCTensor_(nDimension)(state, invstd_) != 1 ||
        THCTensor_(nDimension)(state, gamma_) != 1 ||
        THCTensor_(nDimension)(state, beta_) != 1)
        THError("BatchNorm2d forward: incorrect input dims. \n");
    /* Device tensors */
    THCDeviceTensor<real, 3> output = devicetensor<3>(state, output_);
    THCDeviceTensor<real, 3> input  = devicetensor<3>(state, input_);
    THCDeviceTensor<real, 1> mean   = devicetensor<1>(state, mean_);
    THCDeviceTensor<real, 1> invstd = devicetensor<1>(state, invstd_);
    THCDeviceTensor<real, 1> gamma  = devicetensor<1>(state, gamma_);
    THCDeviceTensor<real, 1> beta   = devicetensor<1>(state, beta_);
    /* kernel function */
    cudaStream_t stream = THCState_getCurrentStream(state);
    dim3 blocks(input.getSize(1));
    dim3 threads(getNumThreads(input.getSize(2)));
    Encoding_(BatchNorm_Forward_kernel)<<<blocks, threads, 0, stream>>>(
        output, input, mean, invstd, gamma, beta);
    THCudaCheck(cudaGetLastError());
}
struct Encoding_(Float2) {
    real v1, v2;
    __device__ Encoding_(Float2)() {}
    __device__ Encoding_(Float2)(real x1, real x2) : v1(x1), v2(x2) {}
    __device__ Encoding_(Float2)(real v) : v1(v), v2(v) {}
    __device__ Encoding_(Float2)(int v) : v1(v), v2(v) {}
    __device__ Encoding_(Float2)& operator+=(const Encoding_(Float2)& a) {
        v1 += a.v1;
        v2 += a.v2;
        return *this;
    }
};
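/*
 * Float2 packs two partial sums so that a single block reduction yields
 * both sum(gradOut) and sum(gradOut * (x - mean)) in one pass.
 */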
static __device__ __forceinline__ real Encoding_(rwarpSum)(real val) {
#if __CUDA_ARCH__ >= 300
    for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
        val += __shfl_xor(val, 1 << i, WARP_SIZE);
    }
#else
    __shared__ real values[MAX_BLOCK_SIZE];
    values[threadIdx.x] = val;
    __threadfence_block();
    const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
    for (int i = 1; i < WARP_SIZE; i++) {
        val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
    }
#endif
    return val;
}

static __device__ __forceinline__ Encoding_(Float2) Encoding_(warpSum)(Encoding_(Float2) value) {
    value.v1 = Encoding_(rwarpSum)(value.v1);
    value.v2 = Encoding_(rwarpSum)(value.v2);
    return value;
}
struct Encoding_(GradOp) {
    __device__ Encoding_(GradOp)(real m, THCDeviceTensor<real, 3> i,
                                 THCDeviceTensor<real, 3> g)
        : mean(m), input(i), gradOutput(g) {}
    __device__ __forceinline__ Encoding_(Float2) operator()(int batch, int plane, int n) {
        real g = gradOutput[batch][plane][n].ldg();
        real c = input[batch][plane][n].ldg() - mean;
        return Encoding_(Float2)(g, g * c);
    }
    real mean;
    THCDeviceTensor<real, 3> input;
    THCDeviceTensor<real, 3> gradOutput;
};
// Sum across (batch, x/y/z) applying Op() pointwise
__device__ Encoding_(Float2) Encoding_(reduce)(Encoding_(GradOp) op,
        THCDeviceTensor<real, 3> tensor, int plane) {
    Encoding_(Float2) sum = (Encoding_(Float2))0;
    for (int batch = 0; batch < tensor.getSize(0); ++batch) {
        for (int x = threadIdx.x; x < tensor.getSize(2); x += blockDim.x) {
            sum += op(batch, plane, x);
        }
    }
    // sum over NumThreads within a warp
    sum = Encoding_(warpSum)(sum);
    // 'transpose', and reduce within warp again
    __shared__ Encoding_(Float2) shared[32];
    __syncthreads();
    if (threadIdx.x % WARP_SIZE == 0) {
        if (threadIdx.x / WARP_SIZE < 32) {
            shared[threadIdx.x / WARP_SIZE] = sum;
        }
    }
    if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
        // zero out the other entries in shared
        shared[threadIdx.x] = (Encoding_(Float2))0;
    }
    __syncthreads();
    if (threadIdx.x / WARP_SIZE == 0) {
        sum = Encoding_(warpSum)(shared[threadIdx.x]);
        if (threadIdx.x == 0) {
            shared[0] = sum;
        }
    }
    __syncthreads();
    // Everyone picks it up, should be broadcast into the whole gradInput
    return shared[0];
}
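/*
 * Backward treats mean and std as independent inputs (they arrive through
 * the synchronized sum_square path), so per channel c it reduces to:
 *     gradInput[b][c][x] = gradOut[b][c][x] * gamma[c] * invstd[c]
 *     gradMean[c]  = -gamma[c] * invstd[c]   * sum(gradOut)
 *     gradStd[c]   = -gamma[c] * invstd[c]^2 * sum(gradOut * (x - mean[c]))
 *     gradGamma[c] =  invstd[c] * sum(gradOut * (x - mean[c]))
 *     gradBeta[c]  =  sum(gradOut)
 */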
__global__ void Encoding_(BatchNorm_Backward_kernel) (
    THCDeviceTensor<real, 3> gradoutput,
    THCDeviceTensor<real, 3> input,
    THCDeviceTensor<real, 3> gradinput,
    THCDeviceTensor<real, 1> gradgamma,
    THCDeviceTensor<real, 1> gradbeta,
    THCDeviceTensor<real, 1> mean,
    THCDeviceTensor<real, 1> invstd,
    THCDeviceTensor<real, 1> gamma,
    THCDeviceTensor<real, 1> beta,
    THCDeviceTensor<real, 1> gradMean,
    THCDeviceTensor<real, 1> gradStd,
    int train)
{
    /* Get the channel index */
    int c = blockIdx.x;
    /* main operation */
    Encoding_(GradOp) g(mean[c], input, gradoutput);
    Encoding_(Float2) res = Encoding_(reduce)(g, gradoutput, c);
    real gradOutputSum = res.v1;
    real dotP = res.v2;
    real gradScale = invstd[c].ldg() * gamma[c].ldg();
    if (train && threadIdx.x == 0) {
        gradMean[c] = - gradOutputSum * gamma[c].ldg() * invstd[c].ldg();
        gradStd[c]  = - dotP * gamma[c].ldg() * invstd[c].ldg() * invstd[c].ldg();
    }
    if (gradinput.numElements() > 0) {
        for (int batch = 0; batch < gradoutput.getSize(0); ++batch) {
            for (int x = threadIdx.x; x < gradoutput.getSize(2); x += blockDim.x) {
                gradinput[batch][c][x] = gradoutput[batch][c][x].ldg() * gradScale;
            }
        }
    }
    if (gradgamma.numElements() > 0) {
        if (threadIdx.x == 0) {
            gradgamma[c] += dotP * invstd[c].ldg();
        }
    }
    if (gradbeta.numElements() > 0) {
        if (threadIdx.x == 0) {
            gradbeta[c] += gradOutputSum;
        }
    }
}
void Encoding_(BatchNorm_Backward)(THCState *state,
    THCTensor *gradoutput_, THCTensor *input_, THCTensor *gradinput_,
    THCTensor *gradgamma_, THCTensor *gradbeta_, THCTensor *mean_,
    THCTensor *invstd_, THCTensor *gamma_, THCTensor *beta_,
    THCTensor *gradMean_, THCTensor *gradStd_, int train)
/*
 * batch norm backward function
 * assuming the input is already flattened to (B, C, N)
 */
{
    /* Check the GPU index and tensor dims */
    THCTensor_(checkGPU)(state, 9, gradoutput_, input_, gradinput_,
                         gradgamma_, gradbeta_, mean_, invstd_, gamma_, beta_);
    if (THCTensor_(nDimension)(state, gradoutput_) != 3 ||
        THCTensor_(nDimension)(state, input_) != 3 ||
        THCTensor_(nDimension)(state, gradinput_) != 3 ||
        THCTensor_(nDimension)(state, gradgamma_) != 1 ||
        THCTensor_(nDimension)(state, gradbeta_) != 1 ||
        THCTensor_(nDimension)(state, mean_) != 1 ||
        THCTensor_(nDimension)(state, invstd_) != 1 ||
        THCTensor_(nDimension)(state, gamma_) != 1 ||
        THCTensor_(nDimension)(state, beta_) != 1 ||
        THCTensor_(nDimension)(state, gradMean_) != 1 ||
        THCTensor_(nDimension)(state, gradStd_) != 1)
        THError("BatchNorm2d backward: incorrect input dims. \n");
    /* Device tensors */
    THCDeviceTensor<real, 3> gradoutput = devicetensor<3>(state, gradoutput_);
    THCDeviceTensor<real, 3> input      = devicetensor<3>(state, input_);
    THCDeviceTensor<real, 3> gradinput  = devicetensor<3>(state, gradinput_);
    THCDeviceTensor<real, 1> gradgamma  = devicetensor<1>(state, gradgamma_);
    THCDeviceTensor<real, 1> gradbeta   = devicetensor<1>(state, gradbeta_);
    THCDeviceTensor<real, 1> mean       = devicetensor<1>(state, mean_);
    THCDeviceTensor<real, 1> invstd     = devicetensor<1>(state, invstd_);
    THCDeviceTensor<real, 1> gamma      = devicetensor<1>(state, gamma_);
    THCDeviceTensor<real, 1> beta       = devicetensor<1>(state, beta_);
    THCDeviceTensor<real, 1> gradMean   = devicetensor<1>(state, gradMean_);
    THCDeviceTensor<real, 1> gradStd    = devicetensor<1>(state, gradStd_);
    /* kernel function */
    cudaStream_t stream = THCState_getCurrentStream(state);
    dim3 blocks(input.getSize(1));
    dim3 threads(getNumThreads(input.getSize(2)));
    Encoding_(BatchNorm_Backward_kernel)<<<blocks, threads, 0, stream>>>(
        gradoutput, input, gradinput, gradgamma, gradbeta, mean, invstd,
        gamma, beta, gradMean, gradStd, train);
    THCudaCheck(cudaGetLastError());
}
struct Encoding_(SumOp) {
    __device__ Encoding_(SumOp)(THCDeviceTensor<real, 3> i)
        : input(i) {}
    __device__ __forceinline__ Encoding_(Float2) operator()(int batch, int plane, int n) {
        real g = input[batch][plane][n].ldg();
        return Encoding_(Float2)(g, g * g);
    }
    THCDeviceTensor<real, 3> input;
};
// Sum across (batch, x/y/z) applying Op() pointwise
__device__ Encoding_(Float2) Encoding_(reduce_sum)(Encoding_(SumOp) op,
        THCDeviceTensor<real, 3> tensor, int plane) {
    Encoding_(Float2) sum = (Encoding_(Float2))0;
    for (int batch = 0; batch < tensor.getSize(0); ++batch) {
        for (int x = threadIdx.x; x < tensor.getSize(2); x += blockDim.x) {
            sum += op(batch, plane, x);
        }
    }
    // sum over NumThreads within a warp
    sum = Encoding_(warpSum)(sum);
    // 'transpose', and reduce within warp again
    __shared__ Encoding_(Float2) shared[32];
    __syncthreads();
    if (threadIdx.x % WARP_SIZE == 0) {
        if (threadIdx.x / WARP_SIZE < 32) {
            shared[threadIdx.x / WARP_SIZE] = sum;
        }
    }
    if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
        // zero out the other entries in shared
        shared[threadIdx.x] = (Encoding_(Float2))0;
    }
    __syncthreads();
    if (threadIdx.x / WARP_SIZE == 0) {
        sum = Encoding_(warpSum)(shared[threadIdx.x]);
        if (threadIdx.x == 0) {
            shared[0] = sum;
        }
    }
    __syncthreads();
    // Everyone picks it up, should be broadcast into the whole gradInput
    return shared[0];
}
__global__ void Encoding_(Sum_Square_Forward_kernel) (
    THCDeviceTensor<real, 3> input,
    THCDeviceTensor<real, 1> sum,
    THCDeviceTensor<real, 1> square)
{
    int c = blockIdx.x;
    /* main operation */
    Encoding_(SumOp) g(input);
    Encoding_(Float2) res = Encoding_(reduce_sum)(g, input, c);
    real xsum = res.v1;
    real xsquare = res.v2;
    if (threadIdx.x == 0) {
        sum[c] = xsum;
        square[c] = xsquare;
    }
}
void Encoding_(Sum_Square_Forward)(THCState *state,
    THCTensor *input_, THCTensor *sum_, THCTensor *square_)
/*
 * per-channel sum and sum-of-squares over a (B, C, N) input
 */
{
    /* Check the GPU index and tensor dims */
    THCTensor_(checkGPU)(state, 3, input_, sum_, square_);
    if (THCTensor_(nDimension)(state, input_) != 3 ||
        THCTensor_(nDimension)(state, sum_) != 1 ||
        THCTensor_(nDimension)(state, square_) != 1)
        THError("Sum_Square forward: incorrect input dims. \n");
    /* Device tensors */
    THCDeviceTensor<real, 3> input  = devicetensor<3>(state, input_);
    THCDeviceTensor<real, 1> sum    = devicetensor<1>(state, sum_);
    THCDeviceTensor<real, 1> square = devicetensor<1>(state, square_);
    /* kernel function */
    cudaStream_t stream = THCState_getCurrentStream(state);
    dim3 blocks(input.getSize(1));
    dim3 threads(getNumThreads(input.getSize(2)));
    Encoding_(Sum_Square_Forward_kernel)<<<blocks, threads, 0, stream>>>(
        input, sum, square);
    THCudaCheck(cudaGetLastError());
}
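/*
 * Differentiating sum_c = sum(x) and square_c = sum(x^2) elementwise gives
 *     gradInput[b][c][x] = gradSum[c] + 2 * gradSquare[c] * x[b][c][x]
 * which is exactly what the kernel below computes.
 */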
__global__ void Encoding_(Sum_Square_Backward_kernel) (
    THCDeviceTensor<real, 3> gradInput,
    THCDeviceTensor<real, 3> input,
    THCDeviceTensor<real, 1> gradSum,
    THCDeviceTensor<real, 1> gradSquare)
{
    int c = blockIdx.x;
    /* main operation */
    for (int batch = 0; batch < gradInput.getSize(0); ++batch) {
        for (int x = threadIdx.x; x < gradInput.getSize(2); x += blockDim.x) {
            gradInput[batch][c][x] = gradSum[c] + 2 * gradSquare[c] *
                input[batch][c][x];
        }
    }
}
void Encoding_(Sum_Square_Backward)(THCState *state,
    THCTensor *gradInput_, THCTensor *input_,
    THCTensor *gradSum_, THCTensor *gradSquare_)
/*
 * backward of the per-channel sum and sum-of-squares
 */
{
    /* Check the GPU index and tensor dims */
    THCTensor_(checkGPU)(state, 4, gradInput_, input_, gradSum_,
                         gradSquare_);
    if (THCTensor_(nDimension)(state, gradInput_) != 3 ||
        THCTensor_(nDimension)(state, input_) != 3 ||
        THCTensor_(nDimension)(state, gradSum_) != 1 ||
        THCTensor_(nDimension)(state, gradSquare_) != 1)
        THError("Sum_Square backward: incorrect input dims. \n");
    /* Device tensors */
    THCDeviceTensor<real, 3> gradInput  = devicetensor<3>(state, gradInput_);
    THCDeviceTensor<real, 3> input      = devicetensor<3>(state, input_);
    THCDeviceTensor<real, 1> gradSum    = devicetensor<1>(state, gradSum_);
    THCDeviceTensor<real, 1> gradSquare = devicetensor<1>(state, gradSquare_);
    /* kernel function */
    cudaStream_t stream = THCState_getCurrentStream(state);
    dim3 blocks(input.getSize(1));
    dim3 threads(getNumThreads(input.getSize(2)));
    Encoding_(Sum_Square_Backward_kernel)<<<blocks, threads, 0, stream>>>(
        gradInput, input, gradSum, gradSquare);
    THCudaCheck(cudaGetLastError());
}
#endif
@@ -14,6 +14,26 @@
void Encoding_(Aggregate_Forward)(THCState *state, THCTensor *E_,
    THCTensor *A_, THCTensor *R_);
void Encoding_(Aggregate_Backward)(THCState *state, THCTensor *GA_,
    THCTensor *GR_, THCTensor *L_, THCTensor *A_, THCTensor *R_);
void Encoding_(BatchNorm_Forward)(THCState *state,
    THCTensor *output_, THCTensor *input_,
    THCTensor *mean_, THCTensor *invstd_,
    THCTensor *gamma_, THCTensor *beta_);
void Encoding_(BatchNorm_Backward)(THCState *state,
    THCTensor *gradoutput_, THCTensor *input_, THCTensor *gradinput_,
    THCTensor *gradgamma_, THCTensor *gradbeta_, THCTensor *mean_,
    THCTensor *invstd_, THCTensor *gamma_, THCTensor *beta_,
    THCTensor *gradMean_, THCTensor *gradStd_, int train);
void Encoding_(Sum_Square_Forward)(THCState *state,
    THCTensor *input_, THCTensor *sum_, THCTensor *square_);
void Encoding_(Sum_Square_Backward)(THCState *state,
    THCTensor *gradInput_, THCTensor *input_,
    THCTensor *gradSum_, THCTensor *gradSquare_);
#endif
@@ -9,10 +9,14 @@
*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
*/
#include "thc_encoding.h"
#include "common.h"

#include "generic/device_tensor.h"
#include "THC/THCGenerateFloatType.h"
#include "generic/device_tensor.h"
#include "THC/THCGenerateDoubleType.h"
#ifdef __cplusplus
extern "C" {
#endif
@@ -20,6 +24,9 @@ extern "C" {
#include "generic/encoding_kernel.c"
#include "THC/THCGenerateFloatType.h"
#include "generic/encoding_kernel.c"
#include "THC/THCGenerateDoubleType.h"
#ifdef __cplusplus
}
#endif
@@ -26,6 +26,9 @@ extern "C" {
#include "generic/encoding_kernel.h"
#include "THC/THCGenerateFloatType.h"
#include "generic/encoding_kernel.h"
#include "THC/THCGenerateDoubleType.h"
#ifdef __cplusplus
}
#endif
@@ -20,6 +20,9 @@ extern "C" {
#include "generic/encoding_generic.c"
#include "THC/THCGenerateFloatType.h"
#include "generic/encoding_generic.c"
#include "THC/THCGenerateDoubleType.h"
#ifdef __cplusplus
}
#endif
@@ -24,3 +24,48 @@ int Encoding_Float_aggregate_forward(THCudaTensor *E, THCudaTensor *A,
    THCudaTensor *R);
int Encoding_Float_aggregate_backward(THCudaTensor *GA, THCudaTensor *GR,
    THCudaTensor *L, THCudaTensor *A, THCudaTensor *R);
int Encoding_Float_batchnorm_Forward(THCudaTensor *output_,
    THCudaTensor *input_, THCudaTensor *mean_,
    THCudaTensor *invstd_, THCudaTensor *gamma_, THCudaTensor *beta_);
int Encoding_Float_batchnorm_Backward(THCudaTensor *gradoutput_,
    THCudaTensor *input_, THCudaTensor *gradinput_,
    THCudaTensor *gradgamma_, THCudaTensor *gradbeta_,
    THCudaTensor *mean_, THCudaTensor *invstd_,
    THCudaTensor *gamma_, THCudaTensor *beta_,
    THCudaTensor *gradMean_, THCudaTensor *gradStd_, int train);
int Encoding_Float_sum_square_Forward(THCudaTensor *input_,
    THCudaTensor *sum_, THCudaTensor *square_);
int Encoding_Float_sum_square_Backward(
    THCudaTensor *gradInput_, THCudaTensor *input_,
    THCudaTensor *gradSum_, THCudaTensor *gradSquare_);
/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
int Encoding_Double_aggregate_forward(THCudaDoubleTensor *E,
    THCudaDoubleTensor *A, THCudaDoubleTensor *R);
int Encoding_Double_aggregate_backward(THCudaDoubleTensor *GA,
    THCudaDoubleTensor *GR, THCudaDoubleTensor *L,
    THCudaDoubleTensor *A, THCudaDoubleTensor *R);
int Encoding_Double_batchnorm_Forward(THCudaDoubleTensor *output_,
    THCudaDoubleTensor *input_, THCudaDoubleTensor *mean_,
    THCudaDoubleTensor *invstd_, THCudaDoubleTensor *gamma_,
    THCudaDoubleTensor *beta_);
int Encoding_Double_batchnorm_Backward(THCudaDoubleTensor *gradoutput_,
    THCudaDoubleTensor *input_, THCudaDoubleTensor *gradinput_,
    THCudaDoubleTensor *gradgamma_, THCudaDoubleTensor *gradbeta_,
    THCudaDoubleTensor *mean_, THCudaDoubleTensor *invstd_,
    THCudaDoubleTensor *gamma_, THCudaDoubleTensor *beta_,
    THCudaDoubleTensor *gradMean_, THCudaDoubleTensor *gradStd_, int train);
int Encoding_Double_sum_square_Forward(THCudaDoubleTensor *input_,
    THCudaDoubleTensor *sum_, THCudaDoubleTensor *square_);
int Encoding_Double_sum_square_Backward(
    THCudaDoubleTensor *gradInput_, THCudaDoubleTensor *input_,
    THCudaDoubleTensor *gradSum_, THCudaDoubleTensor *gradSquare_);
@@ -34,4 +34,56 @@ int Encoding_(aggregate_backward)(THCTensor *GA, THCTensor *GR,
    /* C function return number of the outputs */
    return 0;
}
int Encoding_(batchnorm_Forward)(THCTensor *output_, THCTensor *input_,
    THCTensor *mean_, THCTensor *invstd_,
    THCTensor *gamma_, THCTensor *beta_)
/*
 * FFI wrapper around the batch norm forward kernel launcher
 */
{
    Encoding_(BatchNorm_Forward)(state, output_, input_,
        mean_, invstd_, gamma_, beta_);
    /* C function return number of the outputs */
    return 0;
}

int Encoding_(batchnorm_Backward)(THCTensor *gradoutput_,
    THCTensor *input_, THCTensor *gradinput_,
    THCTensor *gradgamma_, THCTensor *gradbeta_, THCTensor *mean_,
    THCTensor *invstd_, THCTensor *gamma_, THCTensor *beta_,
    THCTensor *gradMean_, THCTensor *gradStd_, int train)
/*
 * FFI wrapper around the batch norm backward kernel launcher
 */
{
    Encoding_(BatchNorm_Backward)(state, gradoutput_, input_, gradinput_,
        gradgamma_, gradbeta_, mean_, invstd_, gamma_, beta_, gradMean_,
        gradStd_, train);
    /* C function return number of the outputs */
    return 0;
}

int Encoding_(sum_square_Forward)(THCTensor *input_,
    THCTensor *sum_, THCTensor *square_)
/*
 * FFI wrapper around the sum/square forward kernel launcher
 */
{
    Encoding_(Sum_Square_Forward)(state, input_, sum_, square_);
    /* C function return number of the outputs */
    return 0;
}

int Encoding_(sum_square_Backward)(
    THCTensor *gradInput_, THCTensor *input_,
    THCTensor *gradSum_, THCTensor *gradSquare_)
/*
 * FFI wrapper around the sum/square backward kernel launcher
 */
{
    Encoding_(Sum_Square_Backward)(state, gradInput_, input_, gradSum_,
        gradSquare_);
    /* C function return number of the outputs */
    return 0;
}
#endif
/*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
* Created by: Hang Zhang
* ECE Department, Rutgers University
* Email: zhang.hang@rutgers.edu
* Copyright (c) 2017
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree
*+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
*/
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/encoding_generic.h"
#else
int Encoding_(aggregate_forward)(THCTensor *E, THCTensor *A,
    THCTensor *R);
#endif