Commit 04dc2518 authored by rusty1s

all tests pass

parent 8af3271a
@@ -9,6 +9,3 @@ void THDoubleTensor_weightingBackwardWeight(THDoubleTensor *self, THDoubleTensor
void THFloatTensor_weightingBackwardBasis( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *weight, THLongTensor *weightIndex);
void THDoubleTensor_weightingBackwardBasis(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *weight, THLongTensor *weightIndex);
void THFloatTensor_weightingBackward( THFloatTensor *gradSrc, THFloatTensor *gradWeight, THFloatTensor *gradBasis, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_weightingBackward(THDoubleTensor *gradSrc, THDoubleTensor *gradWeight, THDoubleTensor *gradBasis, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weightIndex);
@@ -116,41 +116,4 @@ void THTensor_(weightingBackwardBasis)(THTensor *self, THTensor *gradOutput, THT
}
}
void THTensor_(weightingBackward)(THTensor *gradSrc, THTensor *gradWeight, THTensor *gradBasis,
THTensor *gradOutput, THTensor *src, THTensor *weight,
THTensor *basis, THLongTensor *weightIndex) {
THTensor_(fill)(gradSrc, 0);
THTensor_(fill)(gradWeight, 0);
THTensor_(fill)(gradBasis, 0);
real *gradSrcData = THTensor_(data)(gradSrc);
real *gradWeightData = THTensor_(data)(gradWeight);
real *gradBasisData = THTensor_(data)(gradBasis);
real *gradOutputData = THTensor_(data)(gradOutput);
real *srcData = THTensor_(data)(src);
real *weightData = THTensor_(data)(weight);
real *basisData = THTensor_(data)(basis);
int64_t *weightIndexData = THLongTensor_data(weightIndex);
ptrdiff_t e, mOut, s, mIn;
real g, b, w, f;
int64_t wi;
for (e = 0; e < THTensor_(size)(src, 0); e++) {
for (mOut = 0; mOut < THTensor_(size)(gradOutput, 1); mOut++) {
g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
for (s = 0; s < THTensor_(size)(basis, 1); s++) {
b = basisData[e * basis->stride[0] + s * basis->stride[1]];
wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
for (mIn = 0; mIn < THTensor_(size)(src, 1); mIn++) {
w = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
f = srcData[e * src->stride[0] + mIn * src->stride[1]];
gradSrcData[e * gradSrc->stride[0] + mIn * gradSrc->stride[1]] += g * w * b;
gradWeightData[wi * gradWeight->stride[0] + mOut * gradWeight->stride[1] + mIn * gradWeight->stride[2]] += f * g * b;
gradBasisData[e * gradBasis->stride[0] + s * gradBasis->stride[1]] += g * w * f;
}
}
}
}
}
#endif // TH_GENERIC_FILE
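
For reference, a minimal NumPy sketch of the math the removed fused CPU backward computed in a single loop nest. Shapes are assumptions read off the indexing above: `src` is `[E, M_in]`, `weight` is `[K, M_in, M_out]`, `basis` and `weight_index` are `[E, S]`, `grad_output` is `[E, M_out]`. Note the C code accumulated `gradWeight` into a `[K, M_out, M_in]` buffer that the Python wrapper transposed back:

```python
import numpy as np

# Illustrative reference only, not the TH source; all names are assumptions.
def fused_backward(grad_output, src, weight, basis, weight_index):
    E, M_in = src.shape
    M_out = grad_output.shape[1]
    S = basis.shape[1]
    grad_src = np.zeros_like(src)
    # gradWeight was written in [K, M_out, M_in] order, per the indexing above
    grad_weight_t = np.zeros((weight.shape[0], M_out, M_in))
    grad_basis = np.zeros_like(basis)
    for e in range(E):
        for m_out in range(M_out):
            g = grad_output[e, m_out]
            for s in range(S):
                b, wi = basis[e, s], weight_index[e, s]
                for m_in in range(M_in):
                    w = weight[wi, m_in, m_out]
                    f = src[e, m_in]
                    grad_src[e, m_in] += g * w * b
                    grad_weight_t[wi, m_out, m_in] += f * g * b
                    grad_basis[e, s] += g * w * f
    # the Python wrapper returned the buffer transposed back to weight's layout
    return grad_src, grad_weight_t.transpose(0, 2, 1), grad_basis
```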
@@ -31,20 +31,22 @@ __global__ void weightingBackwardSrcKernel(TensorInfo<T> self, TensorInfo<T> gra
TensorInfo<T> weight, TensorInfo<T> basis,
TensorInfo<int64_t> weightIndex, int n) {
KERNEL_LOOP(i, n) {
ptrdiff_t e = i / self.size[1], mIn = i % self.size[1], s, mOut;
T v = ScalarConvert<int, T>::to(0), b, tmp;
ptrdiff_t e = i / gradOutput.size[1], mOut = i % gradOutput.size[1], s, mIn;
T v, b, tmp;
int64_t wi;
for (s = 0; s < basis.size[1]; s++) {
b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
for (mOut = 0; mOut < gradOutput.size[1]; mOut++) {
tmp = weight.data[wi * weight.stride[0] + mOut * weight.stride[1] + mIn * weight.stride[2]];
tmp = THCNumerics<T>::mul(tmp, gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]]);
T g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
for (mIn = 0; mIn < self.size[1]; mIn++) {
v = ScalarConvert<int, T>::to(0);
for (s = 0; s < basis.size[1]; s++) {
b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
tmp = weight.data[wi * weight.stride[0] + mIn * weight.stride[1] + mOut * weight.stride[2]];
tmp = THCNumerics<T>::mul(tmp, b);
tmp = THCNumerics<T>::mul(tmp, g);
v = THCNumerics<T>::add(v, tmp);
}
atomicAdd(&self.data[e * self.stride[0] + mIn * self.stride[1]], v);
}
self.data[e * self.stride[0] + mIn * self.stride[1]] = v;
}
}
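
The rewrite flips the thread decomposition of `weightingBackwardSrcKernel`: the old kernel assigned one thread per `gradSrc` entry `(e, mIn)` and gathered into a direct store, while the new one assigns one thread per `gradOutput` entry `(e, mOut)` and scatters partial sums with `atomicAdd`; correspondingly, the host hunk further down launches `nElement(gradOutput)` threads, zero-fills `gradSrc`, and drops the weight transpose. A hedged Python sketch of the new decomposition (sequential stand-in for the CUDA grid; names are illustrative):

```python
import numpy as np

# One loop iteration below plays the role of one CUDA "thread".
def weighting_backward_src(grad_output, weight, basis, weight_index):
    E, M_out = grad_output.shape
    K, M_in, _ = weight.shape           # weight in its native [K, M_in, M_out] layout
    S = basis.shape[1]
    grad_src = np.zeros((E, M_in))      # host-side fill(self, 0)
    for i in range(E * M_out):          # KERNEL_LOOP(i, nElement(gradOutput))
        e, m_out = divmod(i, M_out)
        g = grad_output[e, m_out]
        for m_in in range(M_in):
            v = 0.0
            for s in range(S):
                wi = weight_index[e, s]
                v += weight[wi, m_in, m_out] * basis[e, s] * g
            grad_src[e, m_in] += v      # atomicAdd in the kernel
    return grad_src
```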
@@ -62,8 +64,7 @@ __global__ void weightingBackwardWeightKernel(TensorInfo<T> self, TensorInfo<T>
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
for (mIn = 0; mIn < src.size[1]; mIn++) {
v = src.data[e * src.stride[0] + mIn * src.stride[1]];
v = THCNumerics<T>::mul(v, b);
v = THCNumerics<T>::mul(v, g);
v = THCNumerics<T>::mul(THCNumerics<T>::mul(v, b), g);
atomicAdd(&self.data[wi * self.stride[0] + mIn * self.stride[1] + mOut * self.stride[2]], v);
}
}
@@ -93,37 +94,5 @@ __global__ void weightingBackwardBasisKernel(TensorInfo<T> self, TensorInfo<T> g
}
}
template<typename T>
__global__ void weightingBackwardKernel(TensorInfo<T> gradSrc, TensorInfo<T> gradWeight,
TensorInfo<T> gradBasis, TensorInfo<T> gradOutput,
TensorInfo<T> src, TensorInfo<T> weight,
TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
int n) {
KERNEL_LOOP(i, n) {
ptrdiff_t e = i / src.size[1], mIn = i % src.size[1], s, mOut;
T b, g, w, gs = ScalarConvert<int, T>::to(0), gw, gb;
int64_t wi;
T f = src.data[e * src.stride[0] + mIn * src.stride[1]];
for (s = 0; s < basis.size[1]; s++) {
b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
gb = ScalarConvert<int, T>::to(0);
for (mOut = 0; mOut < gradOutput.size[1]; mOut++) {
g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
w = weight.data[wi * weight.stride[0] + mOut * weight.stride[1] + mIn * weight.stride[2]];
gs = THCNumerics<T>::add(gs, THCNumerics<T>::mul(THCNumerics<T>::mul(b, g), w));
gw = THCNumerics<T>::mul(THCNumerics<T>::mul(f, b), g);
atomicAdd(&gradWeight.data[wi * gradWeight.stride[0] + mOut * gradWeight.stride[1] + mIn * gradWeight.stride[2]], gw);
gb = THCNumerics<T>::add(gb, THCNumerics<T>::mul(THCNumerics<T>::mul(g, f), w));
}
atomicAdd(&gradBasis.data[e * gradBasis.stride[0] + s * gradBasis.stride[1]], gb);
}
gradSrc.data[e * gradSrc.stride[0] + mIn * gradSrc.stride[1]] = gs;
}
}
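
With the fused `weightingBackwardKernel` gone, the backward pass runs the three dedicated kernels independently, gated by which inputs need gradients (see the autograd hunk near the end of this commit). Roughly, in Python terms:

```python
# Hedged sketch mirroring the new SplineWeighting.backward dispatch below;
# `needs` stands in for needs_input_grad.
def backward(grad_output, src, weight, basis, weight_index, needs):
    grad_src = grad_weight = grad_basis = None
    if needs[0]:
        grad_src = weighting_backward_src(grad_output, weight, basis, weight_index)
    if needs[1]:
        grad_weight = weighting_backward_weight(grad_output, src, basis,
                                                weight_index, weight.size(0))
    if needs[2]:
        grad_basis = weighting_backward_basis(grad_output, src, weight, weight_index)
    return grad_src, grad_weight, grad_basis
```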
#include "generic/THCWeighting.cu"
#include "THC/THCGenerateFloatTypes.h"
@@ -5,7 +5,7 @@
for (ptrdiff_t I = blockIdx.x * blockDim.x + threadIdx.x; I < N; I += blockDim.x * gridDim.x)
const int MAX_DIMS = 25;
const int NUM_THREADS = 512;
const int NUM_THREADS = 1024;
inline int GET_BLOCKS(int N) {
return (N + NUM_THREADS - 1) / NUM_THREADS;
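
`GET_BLOCKS` is a plain ceiling division, so doubling `NUM_THREADS` to 1024 halves the grid size for large inputs. A quick sanity check of the arithmetic:

```python
NUM_THREADS = 1024

def get_blocks(n):  # (N + NUM_THREADS - 1) / NUM_THREADS
    return (n + NUM_THREADS - 1) // NUM_THREADS

assert get_blocks(1) == 1
assert get_blocks(1024) == 1
assert get_blocks(1025) == 2
```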
@@ -22,7 +22,7 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso
THCudaLongTensor *weightIndex) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, weight, basis, weightIndex));
weight = THCTensor_(newTranspose)(state, weight, 1, 2);
THCTensor_(fill)(state, self, ScalarConvert<int, real>::to(0));
TensorInfo<real> selfInfo = THCTensor_(getTensorInfo)(state, self);
TensorInfo<real> gradOutputInfo = THCTensor_(getTensorInfo)(state, gradOutput);
@@ -30,10 +30,8 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso
TensorInfo<real> basisInfo = THCTensor_(getTensorInfo)(state, basis);
TensorInfo<int64_t> weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex);
KERNEL_REAL_RUN(weightingBackwardSrcKernel, THCTensor_(nElement)(state, self), selfInfo,
KERNEL_REAL_RUN(weightingBackwardSrcKernel, THCTensor_(nElement)(state, gradOutput), selfInfo,
gradOutputInfo, weightInfo, basisInfo, weightIndexInfo);
THCTensor_(free)(state, weight);
}
void THCTensor_(weightingBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput,
@@ -70,32 +68,4 @@ void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTen
gradOutputInfo, srcInfo, weightInfo, weightIndexInfo);
}
void THCTensor_(weightingBackward)(THCState *state, THCTensor *gradSrc, THCTensor *gradWeight,
THCTensor *gradBasis, THCTensor *gradOutput, THCTensor *src,
THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 8, gradSrc, gradWeight, gradBasis, src, weight,
basis, weightIndex));
THCTensor_(fill)(state, gradWeight, ScalarConvert<int, real>::to(0));
THCTensor_(fill)(state, gradBasis, ScalarConvert<int, real>::to(0));
weight = THCTensor_(newTranspose)(state, weight, 1, 2);
TensorInfo<real> gradSrcInfo = THCTensor_(getTensorInfo)(state, gradSrc);
TensorInfo<real> gradWeightInfo = THCTensor_(getTensorInfo)(state, gradWeight);
TensorInfo<real> gradBasisInfo = THCTensor_(getTensorInfo)(state, gradBasis);
TensorInfo<real> gradOutputInfo = THCTensor_(getTensorInfo)(state, gradOutput);
TensorInfo<real> srcInfo = THCTensor_(getTensorInfo)(state, src);
TensorInfo<real> weightInfo = THCTensor_(getTensorInfo)(state, weight);
TensorInfo<real> basisInfo = THCTensor_(getTensorInfo)(state, basis);
TensorInfo<int64_t> weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex);
KERNEL_REAL_RUN(weightingBackwardKernel, THCTensor_(nElement)(state, src), gradSrcInfo,
gradWeightInfo, gradBasisInfo, gradOutputInfo, srcInfo, weightInfo, basisInfo,
weightIndexInfo);
THCTensor_(free)(state, weight);
}
#endif // THC_GENERIC_FILE
@@ -18,9 +18,4 @@ void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTen
THCTensor *src, THCTensor *weight,
THCudaLongTensor *weightIndex);
void THCTensor_(weightingBackward)(THCState *state, THCTensor *gradSrc, THCTensor *gradWeight,
THCTensor *gradBasis, THCTensor *gradOutput, THCTensor *src,
THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex);
#endif // THC_GENERIC_FILE
@@ -9,6 +9,3 @@ void THCCDoubleTensor_weightingBackwardWeight(THCudaDoubleTensor *self, THCudaDo
void THCCFloatTensor_weightingBackwardBasis( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *src, THCudaTensor *weight, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_weightingBackwardBasis(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *src, THCudaDoubleTensor *weight, THCudaLongTensor *weightIndex);
void THCCFloatTensor_weightingBackward( THCudaTensor *gradSrc, THCudaTensor *gradWeight, THCudaTensor *gradBasis, THCudaTensor *gradOutput, THCudaTensor *src, THCudaTensor *weight, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_weightingBackward(THCudaDoubleTensor *gradSrc, THCudaDoubleTensor *gradWeight, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *src, THCudaDoubleTensor *weight, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
@@ -22,12 +22,4 @@ void THCCTensor_(weightingBackwardBasis)(THCTensor *self, THCTensor *gradOutput,
THCTensor_(weightingBackwardBasis)(state, self, gradOutput, src, weight, weightIndex);
}
void THCCTensor_(weightingBackward)(THCTensor *gradSrc, THCTensor *gradWeight,
THCTensor *gradBasis, THCTensor *gradOutput, THCTensor *src,
THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex) {
THCTensor_(weightingBackward)(state, gradSrc, gradWeight, gradBasis, gradOutput, src, weight,
basis, weightIndex);
}
#endif // THC_GENERIC_FILE
@@ -76,4 +76,4 @@ def test_spline_basis_backward_gpu(degree):
pseudo = Variable(pseudo, requires_grad=True)
op = SplineBasis(degree, kernel_size, is_open_spline)
# assert gradcheck(op, (pseudo, ), eps=1e-6, atol=1e-4) is True
assert gradcheck(op, (pseudo, ), eps=1e-6, atol=1e-4) is True
@@ -73,8 +73,8 @@ def test_spline_basis_backward_gpu():
basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline)
src = Variable(src, requires_grad=True)
weight = Variable(weight, requires_grad=True)
basis = Variable(basis, requires_grad=True)
weight = Variable(weight, requires_grad=False)
basis = Variable(basis, requires_grad=False)
op = SplineWeighting(weight_index)
assert gradcheck(op, (src, weight, basis), eps=1e-6, atol=1e-4) is True
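
Context for the test change: `torch.autograd.gradcheck` compares analytic gradients against finite differences for every input with `requires_grad=True`, so flipping `weight` and `basis` to `requires_grad=False` narrows this GPU test to the `src` gradient. The general pattern, as a standalone sketch (module and shapes are illustrative, not from this repo):

```python
# Minimal gradcheck example; double precision keeps the finite-difference
# comparison stable at eps=1e-6.
import torch
from torch.autograd import Variable, gradcheck

lin = torch.nn.Linear(3, 2).double()
x = Variable(torch.randn(4, 3).double(), requires_grad=True)
assert gradcheck(lambda x: lin(x), (x, ), eps=1e-6, atol=1e-4) is True
```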
@@ -48,10 +48,3 @@ def weighting_backward_weight(self, grad_output, src, basis, weight_index):
def weighting_backward_basis(self, grad_output, src, weight, weight_index):
func = get_func('weightingBackwardBasis', self.is_cuda, self)
func(self, grad_output, src, weight, weight_index)
def weighting_backward(grad_src, grad_weight, grad_basis, grad_output, src,
weight, basis, weight_index):
func = get_func('weightingBackward', grad_src.is_cuda, grad_src)
func(grad_src, grad_weight, grad_basis, grad_output, src, weight, basis,
weight_index)
@@ -5,7 +5,6 @@ from .utils.ffi import weighting_forward as weighting_fw
from .utils.ffi import weighting_backward_src as weighting_bw_src
from .utils.ffi import weighting_backward_weight as weighting_bw_weight
from .utils.ffi import weighting_backward_basis as weighting_bw_basis
from .utils.ffi import weighting_backward as weighting_bw
def weighting_forward(src, weight, basis, weight_index):
@@ -32,16 +31,6 @@ def weighting_backward_basis(grad_output, src, weight, weight_index):
return grad_basis
def weighting_backward(grad_output, src, weight, basis, weight_index):
grad_src = src.new(src.size())
# grad_weight = weight.new(weight.size())
grad_weight = weight.new(weight.size(0), weight.size(2), weight.size(1))
grad_basis = basis.new(basis.size())
weighting_bw(grad_src, grad_weight, grad_basis, grad_output, src, weight,
basis, weight_index)
return grad_src, grad_weight.transpose(1, 2), grad_basis
class SplineWeighting(Function):
def __init__(self, weight_index):
super(SplineWeighting, self).__init__()
@@ -55,20 +44,14 @@ class SplineWeighting(Function):
grad_src = grad_weight = grad_basis = None
src, weight, basis = self.saved_tensors
needs_src, needs_weight, needs_basis = self.needs_input_grad
if needs_src and needs_weight and needs_basis:
return weighting_backward(grad_output, src, weight, basis,
self.weight_index)
if needs_src:
if self.needs_input_grad[0]:
grad_src = weighting_backward_src(grad_output, weight, basis,
self.weight_index)
if needs_weight:
if self.needs_input_grad[1]:
K = weight.size(0)
grad_weight = weighting_backward_weight(grad_output, src, basis,
self.weight_index, K)
if needs_basis:
if self.needs_input_grad[2]:
grad_basis = weighting_backward_basis(grad_output, src, weight,
self.weight_index)