Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
dd9778c8
Commit
dd9778c8
authored
Mar 14, 2018
by
rusty1s
Browse files
added weighting forward
parent
bbe0254f
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
75 additions
and
2 deletions
+75
-2
test/test_spline_conv.py
test/test_spline_conv.py
+26
-0
torch_spline_conv/kernel/THCBasisForward.cuh
torch_spline_conv/kernel/THCBasisForward.cuh
+1
-1
torch_spline_conv/kernel/common.cuh
torch_spline_conv/kernel/common.cuh
+8
-1
torch_spline_conv/kernel/generic/kernel.cu
torch_spline_conv/kernel/generic/kernel.cu
+13
-0
torch_spline_conv/kernel/kernel.cu
torch_spline_conv/kernel/kernel.cu
+17
-0
torch_spline_conv/kernel/kernel.h
torch_spline_conv/kernel/kernel.h
+3
-0
torch_spline_conv/src/cuda.h
torch_spline_conv/src/cuda.h
+3
-0
torch_spline_conv/src/generic/cuda.c
torch_spline_conv/src/generic/cuda.c
+4
-0
No files found.
test/test_spline_conv.py
View file @
dd9778c8
...
...
@@ -62,3 +62,29 @@ def test_spline_weighting_backward_cpu():
weight
=
Variable
(
weight
,
requires_grad
=
True
)
assert
gradcheck
(
op
,
(
x
,
pseudo
,
weight
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
@
pytest
.
mark
.
parametrize
(
'tensor'
,
tensors
)
def
test_spline_conv_gpu
(
tensor
):
x
=
Tensor
(
tensor
,
[[
9
,
10
],
[
1
,
2
],
[
3
,
4
],
[
5
,
6
],
[
7
,
8
]])
edge_index
=
torch
.
LongTensor
([[
0
,
0
,
0
,
0
],
[
1
,
2
,
3
,
4
]])
pseudo
=
[[
0.25
,
0.125
],
[
0.25
,
0.375
],
[
0.75
,
0.625
],
[
0.75
,
0.875
]]
pseudo
=
Tensor
(
tensor
,
pseudo
)
weight
=
torch
.
arange
(
0.5
,
0.5
*
25
,
step
=
0.5
,
out
=
x
.
new
()).
view
(
12
,
2
,
1
)
kernel_size
=
torch
.
LongTensor
([
3
,
4
])
is_open_spline
=
torch
.
ByteTensor
([
1
,
0
])
root_weight
=
torch
.
arange
(
12.5
,
13.5
,
step
=
0.5
,
out
=
x
.
new
()).
view
(
2
,
1
)
bias
=
Tensor
(
tensor
,
[
1
])
expected_output
=
spline_conv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
root_weight
,
1
,
bias
)
x
,
edge_index
,
pseudo
=
x
.
cuda
(),
edge_index
.
cuda
(),
pseudo
.
cuda
()
weight
,
kernel_size
=
weight
.
cuda
(),
kernel_size
.
cuda
()
is_open_spline
,
root_weight
=
is_open_spline
.
cuda
(),
root_weight
.
cuda
()
bias
=
bias
.
cuda
()
output
=
spline_conv
(
x
,
edge_index
,
pseudo
,
weight
,
kernel_size
,
is_open_spline
,
root_weight
,
1
,
bias
)
assert
output
.
cpu
().
tolist
()
==
expected_output
.
tolist
()
torch_spline_conv/kernel/THCBasisForward.cuh
View file @
dd9778c8
...
...
@@ -8,7 +8,7 @@
int64_t *kernelSizeData = THCudaLongTensor_data(state, kernel_size); \
uint8_t *isOpenSplineData = THCudaByteTensor_data(state, is_open_spline); \
\
KERNEL_RUN(NAME, pseudoInfo.size[1], n, basisInfo, weightIndexInfo, pseudoInfo, kernelSizeData, isOpenSplineData, K) \
KERNEL_
D_
RUN(NAME, pseudoInfo.size[1], n, basisInfo, weightIndexInfo, pseudoInfo, kernelSizeData, isOpenSplineData, K) \
}
template
<
typename
Real
,
int
M
,
int
D
>
...
...
torch_spline_conv/kernel/common.cuh
View file @
dd9778c8
...
...
@@ -24,7 +24,14 @@ struct TensorInfo {
#define KERNEL_LOOP(I, N) \
for (int I = blockIdx.x * blockDim.x + threadIdx.x; I < N; i += blockDim.x * gridDim.x)
#define KERNEL_RUN(NAME, D, N, ...) { \
#define KERNEL_RUN(NAME, N, ...) { \
int grid = GET_BLOCKS(N); \
cudaStream_t stream = THCState_getCurrentStream(state); \
NAME<real><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); \
THCudaCheck(cudaGetLastError()); \
}
#define KERNEL_D_RUN(NAME, D, N, ...) { \
int grid = GET_BLOCKS(N); \
cudaStream_t stream = THCState_getCurrentStream(state); \
switch (D) { \
...
...
torch_spline_conv/kernel/generic/kernel.cu
View file @
dd9778c8
...
...
@@ -14,4 +14,17 @@ void spline_(cubic_basis_forward)(THCState *state, THCTensor *basis, THCudaLongT
SPLINE_BASIS_FORWARD
(
cubicBasisForwardKernel
,
basis
,
weight_index
,
pseudo
,
kernel_size
,
is_open_spline
,
K
)
}
void
spline_
(
weighting_forward
)(
THCState
*
state
,
THCTensor
*
output
,
THCTensor
*
input
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weight_index
)
{
THCAssertSameGPU
(
THCTensor_
(
checkGPU
)(
state
,
4
,
input
,
weight
,
basis
,
weight_index
));
const
int
n
=
THCTensor_
(
nElement
)(
state
,
output
);
TensorInfo
<
real
>
outputInfo
=
thc_
(
getTensorInfo
)(
state
,
output
);
TensorInfo
<
real
>
inputInfo
=
thc_
(
getTensorInfo
)(
state
,
input
);
TensorInfo
<
real
>
weightInfo
=
thc_
(
getTensorInfo
)(
state
,
weight
);
TensorInfo
<
real
>
basisInfo
=
thc_
(
getTensorInfo
)(
state
,
basis
);
TensorInfo
<
int64_t
>
weightIndexInfo
=
thc_getTensorInfo_Long
(
state
,
weight_index
);
KERNEL_RUN
(
weightingForwardKernel
,
n
,
outputInfo
,
inputInfo
,
weightInfo
,
basisInfo
,
weightIndexInfo
)
}
#endif
torch_spline_conv/kernel/kernel.cu
View file @
dd9778c8
...
...
@@ -11,6 +11,23 @@
#include "generic/common.cu"
#include "THCGenerateAllTypes.h"
template
<
typename
Real
>
__global__
void
weightingForwardKernel
(
TensorInfo
<
Real
>
output
,
TensorInfo
<
Real
>
input
,
TensorInfo
<
Real
>
weight
,
TensorInfo
<
Real
>
basis
,
TensorInfo
<
int64_t
>
weightIndex
,
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
int64_t
edgeOffset
=
i
/
output
.
size
[
1
],
inputOffset
=
edgeOffset
*
input
.
stride
[
0
];
int64_t
s
,
S
=
basis
.
size
[
1
],
m_in
,
M_in
=
input
.
size
[
1
],
m_out
=
i
%
output
.
size
[
1
],
M_out
=
output
.
size
[
1
],
weightOffset
;
Real
b
,
value
=
0
;
for
(
s
=
0
;
s
<
S
;
s
++
)
{
b
=
basis
.
data
[
edgeOffset
+
s
];
weightOffset
=
weightIndex
.
data
[
edgeOffset
*
S
+
s
]
*
M_in
*
M_out
+
m_out
;
for
(
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
value
+=
b
*
weight
.
data
[
weightOffset
+
m_in
*
M_out
]
*
input
.
data
[
inputOffset
+
m_in
*
input
.
stride
[
1
]];
}
}
output
.
data
[
i
]
=
value
;
}
}
#include "generic/kernel.cu"
#include "THCGenerateFloatType.h"
#include "generic/kernel.cu"
...
...
torch_spline_conv/kernel/kernel.h
View file @
dd9778c8
...
...
@@ -9,6 +9,9 @@ void spline_quadratic_basis_forward_kernel_Double(THCState *state, THCudaDoubleT
void
spline_cubic_basis_forward_kernel_Float
(
THCState
*
state
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
,
THCudaTensor
*
pseudo
,
THCudaLongTensor
*
kernel_size
,
THCudaByteTensor
*
is_open_spline
,
int
K
);
void
spline_cubic_basis_forward_kernel_Double
(
THCState
*
state
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
,
THCudaDoubleTensor
*
pseudo
,
THCudaLongTensor
*
kernel_size
,
THCudaByteTensor
*
is_open_spline
,
int
K
);
void
spline_weighting_forward_kernel_Float
(
THCState
*
state
,
THCudaTensor
*
output
,
THCudaTensor
*
input
,
THCudaTensor
*
weight
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_forward_kernel_Double
(
THCState
*
state
,
THCudaDoubleTensor
*
output
,
THCudaDoubleTensor
*
input
,
THCudaDoubleTensor
*
weight
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
#ifdef __cplusplus
}
#endif
torch_spline_conv/src/cuda.h
View file @
dd9778c8
...
...
@@ -4,3 +4,6 @@ void spline_quadratic_basis_forward_cuda_Float ( THCudaTensor *basis, THCud
void
spline_quadratic_basis_forward_cuda_Double
(
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
,
THCudaDoubleTensor
*
pseudo
,
THCudaLongTensor
*
kernel_size
,
THCudaByteTensor
*
is_open_spline
,
int
K
);
void
spline_cubic_basis_forward_cuda_Float
(
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
,
THCudaTensor
*
pseudo
,
THCudaLongTensor
*
kernel_size
,
THCudaByteTensor
*
is_open_spline
,
int
K
);
void
spline_cubic_basis_forward_cuda_Double
(
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
,
THCudaDoubleTensor
*
pseudo
,
THCudaLongTensor
*
kernel_size
,
THCudaByteTensor
*
is_open_spline
,
int
K
);
void
spline_weighting_forward_cuda_Float
(
THCudaTensor
*
output
,
THCudaTensor
*
input
,
THCudaTensor
*
weight
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
void
spline_weighting_forward_cuda_Double
(
THCudaDoubleTensor
*
output
,
THCudaDoubleTensor
*
input
,
THCudaDoubleTensor
*
weight
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weight_index
);
torch_spline_conv/src/generic/cuda.c
View file @
dd9778c8
...
...
@@ -14,4 +14,8 @@ void spline_(cubic_basis_forward)(THCTensor *basis, THCudaLongTensor *weight_ind
spline_kernel_
(
cubic_basis_forward
)(
state
,
basis
,
weight_index
,
pseudo
,
kernel_size
,
is_open_spline
,
K
);
}
void
spline_
(
weighting_forward
)(
THCTensor
*
output
,
THCTensor
*
input
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weight_index
)
{
spline_kernel_
(
weighting_forward
)(
state
,
output
,
input
,
weight
,
basis
,
weight_index
);
}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment