Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
9a9a511c
Commit
9a9a511c
authored
Mar 15, 2018
by
rusty1s
Browse files
added backward impl
parent
90576890
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
94 additions
and
3 deletions
+94
-3
torch_spline_conv/kernel/THCAtomics.cuh
torch_spline_conv/kernel/THCAtomics.cuh
+14
-0
torch_spline_conv/kernel/generic/kernel.cu
torch_spline_conv/kernel/generic/kernel.cu
+21
-2
torch_spline_conv/kernel/kernel.cu
torch_spline_conv/kernel/kernel.cu
+39
-1
torch_spline_conv/kernel/kernel.h
torch_spline_conv/kernel/kernel.h
+6
-0
torch_spline_conv/src/cuda.h
torch_spline_conv/src/cuda.h
+6
-0
torch_spline_conv/src/generic/cuda.c
torch_spline_conv/src/generic/cuda.c
+8
-0
No files found.
torch_spline_conv/kernel/THCAtomics.cuh
0 → 100644
View file @
9a9a511c
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static
inline
__device__
void
atomicAdd
(
double
*
address
,
double
val
)
{
unsigned
long
long
int
*
address_as_ull
=
(
unsigned
long
long
int
*
)
address
;
unsigned
long
long
int
old
=
*
address_as_ull
;
unsigned
long
long
int
assumed
;
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ull
,
assumed
,
__double_as_longlong
(
val
+
__longlong_as_double
(
assumed
)));
}
while
(
assumed
!=
old
);
}
#elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000)
static
inline
__device__
void
atomicAdd
(
double
*
address
,
double
val
)
{}
#endif
torch_spline_conv/kernel/generic/kernel.cu
View file @
9a9a511c
...
...
@@ -17,14 +17,33 @@ void spline_(cubic_basis_forward)(THCState *state, THCTensor *basis, THCudaLongT
void
spline_
(
weighting_forward
)(
THCState
*
state
,
THCTensor
*
output
,
THCTensor
*
input
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weight_index
)
{
THCAssertSameGPU
(
THCTensor_
(
checkGPU
)(
state
,
4
,
input
,
weight
,
basis
,
weight_index
));
const
int
n
=
THCTensor_
(
nElement
)(
state
,
output
);
TensorInfo
<
real
>
outputInfo
=
thc_
(
getTensorInfo
)(
state
,
output
);
TensorInfo
<
real
>
inputInfo
=
thc_
(
getTensorInfo
)(
state
,
input
);
TensorInfo
<
real
>
weightInfo
=
thc_
(
getTensorInfo
)(
state
,
weight
);
TensorInfo
<
real
>
basisInfo
=
thc_
(
getTensorInfo
)(
state
,
basis
);
TensorInfo
<
int64_t
>
weightIndexInfo
=
thc_getTensorInfo_Long
(
state
,
weight_index
);
KERNEL_RUN
(
weightingForwardKernel
,
n
,
outputInfo
,
inputInfo
,
weightInfo
,
basisInfo
,
weightIndexInfo
)
KERNEL_RUN
(
weightingForwardKernel
,
THCTensor_
(
nElement
)(
state
,
output
),
outputInfo
,
inputInfo
,
weightInfo
,
basisInfo
,
weightIndexInfo
)
}
void
spline_
(
weighting_backward_input
)(
THCState
*
state
,
THCTensor
*
grad_input
,
THCTensor
*
grad_output
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weight_index
)
{
TensorInfo
<
real
>
gradInputInfo
=
thc_
(
getTensorInfo
)(
state
,
grad_input
);
TensorInfo
<
real
>
gradOutputInfo
=
thc_
(
getTensorInfo
)(
state
,
grad_output
);
TensorInfo
<
real
>
weightInfo
=
thc_
(
getTensorInfo
)(
state
,
weight
);
TensorInfo
<
real
>
basisInfo
=
thc_
(
getTensorInfo
)(
state
,
basis
);
TensorInfo
<
int64_t
>
weightIndexInfo
=
thc_getTensorInfo_Long
(
state
,
weight_index
);
KERNEL_RUN
(
weightingBackwardInputKernel
,
THCTensor_
(
nElement
)(
state
,
grad_input
),
gradInputInfo
,
gradOutputInfo
,
weightInfo
,
basisInfo
,
weightIndexInfo
)
}
void
spline_
(
weighting_backward_weight
)(
THCState
*
state
,
THCTensor
*
grad_weight
,
THCTensor
*
grad_output
,
THCTensor
*
input
,
THCTensor
*
basis
,
THCudaLongTensor
*
weight_index
)
{
TensorInfo
<
real
>
gradWeightInfo
=
thc_
(
getTensorInfo
)(
state
,
grad_weight
);
TensorInfo
<
real
>
gradOutputInfo
=
thc_
(
getTensorInfo
)(
state
,
grad_output
);
TensorInfo
<
real
>
inputInfo
=
thc_
(
getTensorInfo
)(
state
,
input
);
TensorInfo
<
real
>
basisInfo
=
thc_
(
getTensorInfo
)(
state
,
basis
);
TensorInfo
<
int64_t
>
weightIndexInfo
=
thc_getTensorInfo_Long
(
state
,
weight_index
);
KERNEL_RUN
(
weightingBackwardWeightKernel
,
THCTensor_
(
nElement
)(
state
,
grad_output
),
gradWeightInfo
,
gradOutputInfo
,
inputInfo
,
basisInfo
,
weightIndexInfo
)
}
#endif
torch_spline_conv/kernel/kernel.cu
View file @
9a9a511c
...
...
@@ -4,6 +4,7 @@
#include "common.cuh"
#include "THCBasisForward.cuh"
#include "THCAtomics.cuh"
#define spline_(NAME) TH_CONCAT_4(spline_, NAME, _kernel_, Real)
#define thc_(NAME) TH_CONCAT_4(thc_, NAME, _, Real)
...
...
@@ -18,7 +19,7 @@ __global__ void weightingForwardKernel(TensorInfo<Real> output, TensorInfo<Real>
int64_t
s
,
S
=
basis
.
size
[
1
],
m_in
,
M_in
=
input
.
size
[
1
],
m_out
=
i
%
output
.
size
[
1
],
M_out
=
output
.
size
[
1
],
weightOffset
;
Real
b
,
value
=
0
;
for
(
s
=
0
;
s
<
S
;
s
++
)
{
b
=
basis
.
data
[
edgeOffset
+
s
];
b
=
basis
.
data
[
edgeOffset
*
S
+
s
];
weightOffset
=
weightIndex
.
data
[
edgeOffset
*
S
+
s
]
*
M_in
*
M_out
+
m_out
;
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
value
+=
b
*
weight
.
data
[
weightOffset
+
m_in
*
M_out
]
*
input
.
data
[
inputOffset
+
m_in
*
input
.
stride
[
1
]];
...
...
@@ -28,6 +29,43 @@ __global__ void weightingForwardKernel(TensorInfo<Real> output, TensorInfo<Real>
}
}
template
<
typename
Real
>
__global__
void
weightingBackwardInputKernel
(
TensorInfo
<
Real
>
gradInput
,
TensorInfo
<
Real
>
gradOutput
,
TensorInfo
<
Real
>
weight
,
TensorInfo
<
Real
>
basis
,
TensorInfo
<
int64_t
>
weightIndex
,
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
int64_t
edgeOffset
=
i
/
gradInput
.
size
[
1
],
gradOutputOffset
=
edgeOffset
*
M_out
;
int64_t
s
,
S
=
basis
.
size
[
1
],
m_in
=
i
%
gradInput
.
size
[
1
],
M_in
=
gradInput
.
size
[
1
],
m_out
,
M_out
=
gradOutput
.
size
[
1
],
weightOffset
;
Real
b
,
value
=
0
;
for
(
s
=
0
;
s
<
S
;
s
++
)
{
b
=
basis
.
data
[
edgeOffset
*
S
+
s
];
weightOffset
=
weightIndex
.
data
[
edgeOffset
*
S
+
s
]
*
M_in
*
M_out
;
for
(
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
value
+=
b
*
weight
.
data
[
weightOffset
+
m_in
*
M_out
+
m_out
]
*
gradOutput
.
data
[
gradOutputOffset
+
m_out
];
}
}
gradInput
.
data
[
i
]
=
value
;
}
}
template
<
typename
Real
>
__global__
void
weightingBackwardWeightKernel
(
TensorInfo
<
Real
>
gradWeight
,
TensorInfo
<
Real
>
gradOutput
,
TensorInfo
<
Real
>
input
,
TensorInfo
<
Real
>
basis
,
TensorInfo
<
int64_t
>
weightIndex
,
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
int64_t
edgeOffset
=
i
/
gradOutput
.
size
[
1
],
inputOffset
=
edgeOffset
*
input
.
stride
[
0
];
int64_t
s
,
S
=
basis
.
size
[
1
];
int64_t
m_in
,
M_in
=
input
.
size
[
1
];
int64_t
m_out
=
i
%
gradOutput
.
size
[
1
],
M_out
=
gradOutput
.
size
[
1
];
int64_t
weightOffset
;
Real
b
;
Real
value
=
gradOutput
.
data
[
edgeOffset
*
M_out
+
m_out
];
for
(
s
=
0
;
s
<
S
;
s
++
)
{
b
=
basis
.
data
[
edgeOffset
*
S
+
s
];
weightOffset
=
weightIndex
.
data
[
edgeOffset
*
S
+
s
]
*
M_in
*
M_out
+
m_out
;
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
atomicAdd
(
&
gradWeight
.
data
[
weightOffset
+
m_in
*
M_out
],
b
*
value
*
input
.
data
[
inputOffset
+
m_in
*
input
.
stride
[
1
]]);
}
}
}
}
#include "generic/kernel.cu"
#include "THCGenerateFloatType.h"
#include "generic/kernel.cu"
...
...
torch_spline_conv/kernel/kernel.h
View file @
9a9a511c
...
...
@@ -12,6 +12,12 @@ void spline_cubic_basis_forward_kernel_Double(THCState *state, THCudaDoubleT
void
spline_weighting_forward_kernel_Float
(
THCState
*
state
,
THCudaTensor
*
output
,
THCudaTensor
*
input
,
THCudaTensor
*
weight
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_forward_kernel_Double
(
THCState
*
state
,
THCudaDoubleTensor
*
output
,
THCudaDoubleTensor
*
input
,
THCudaDoubleTensor
*
weight
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_backward_input_kernel_Float
(
THCState
*
state
,
THCudaTensor
*
grad_input
,
THCudaTensor
*
grad_output
,
THCudaTensor
*
weight
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_backward_input_kernel_Double
(
THCState
*
state
,
THCudaDoubleTensor
*
grad_input
,
THCudaDoubleTensor
*
grad_output
,
THCudaDoubleTensor
*
weight
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_kernel_Float
(
THCState
*
state
,
THCudaTensor
*
grad_weight
,
THCudaTensor
*
grad_output
,
THCudaTensor
*
input
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_kernel_Double
(
THCState
*
state
,
THCudaDoubleTensor
*
grad_weight
,
THCudaDoubleTensor
*
grad_output
,
THCudaDoubleTensor
*
input
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
#ifdef __cplusplus
}
#endif
torch_spline_conv/src/cuda.h
View file @
9a9a511c
...
...
@@ -7,3 +7,9 @@ void spline_cubic_basis_forward_cuda_Double(THCudaDoubleTensor *basis, THCud
void
spline_weighting_forward_cuda_Float
(
THCudaTensor
*
output
,
THCudaTensor
*
input
,
THCudaTensor
*
weight
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_forward_cuda_Double
(
THCudaDoubleTensor
*
output
,
THCudaDoubleTensor
*
input
,
THCudaDoubleTensor
*
weight
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_backward_input_cuda_Float
(
THCudaTensor
*
grad_input
,
THCudaTensor
*
grad_output
,
THCudaTensor
*
weight
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_backward_input_cuda_Double
(
THCudaDoubleTensor
*
grad_input
,
THCudaDoubleTensor
*
grad_output
,
THCudaDoubleTensor
*
weight
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_cuda_Float
(
THCudaTensor
*
grad_weight
,
THCudaTensor
*
grad_output
,
THCudaTensor
*
input
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_cuda_Double
(
THCudaDoubleTensor
*
grad_weight
,
THCudaDoubleTensor
*
grad_output
,
THCudaDoubleTensor
*
input
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
torch_spline_conv/src/generic/cuda.c
View file @
9a9a511c
...
...
@@ -18,4 +18,12 @@ void spline_(weighting_forward)(THCTensor *output, THCTensor *input, THCTensor *
spline_kernel_
(
weighting_forward
)(
state
,
output
,
input
,
weight
,
basis
,
weight_index
);
}
void
spline_
(
weighting_backward_input
)(
THCTensor
*
grad_input
,
THCTensor
*
grad_output
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weight_index
)
{
spline_kernel_
(
weighting_backward_input
)(
state
,
grad_input
,
grad_output
,
weight
,
basis
,
weight_index
);
}
void
spline_
(
weighting_backward_weight
)(
THCTensor
*
grad_weight
,
THCTensor
*
grad_output
,
THCTensor
*
input
,
THCTensor
*
basis
,
THCudaLongTensor
*
weight_index
)
{
spline_kernel_
(
weighting_backward_weight
)(
state
,
grad_weight
,
grad_output
,
input
,
basis
,
weight_index
);
}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment