Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
04dc2518
Commit
04dc2518
authored
Apr 12, 2018
by
rusty1s
Browse files
all tests pass
parent
8af3271a
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
21 additions
and
162 deletions
+21
-162
aten/TH/THWeighting.h
aten/TH/THWeighting.h
+0
-3
aten/TH/generic/THWeighting.c
aten/TH/generic/THWeighting.c
+0
-37
aten/THC/THCWeighting.cu
aten/THC/THCWeighting.cu
+12
-43
aten/THC/common.cuh
aten/THC/common.cuh
+1
-1
aten/THC/generic/THCWeighting.cu
aten/THC/generic/THCWeighting.cu
+2
-32
aten/THC/generic/THCWeighting.h
aten/THC/generic/THCWeighting.h
+0
-5
aten/THCC/THCCWeighting.h
aten/THCC/THCCWeighting.h
+0
-3
aten/THCC/generic/THCCWeighting.c
aten/THCC/generic/THCCWeighting.c
+0
-8
test/test_basis.py
test/test_basis.py
+1
-1
test/test_weighting.py
test/test_weighting.py
+2
-2
torch_spline_conv/utils/ffi.py
torch_spline_conv/utils/ffi.py
+0
-7
torch_spline_conv/weighting.py
torch_spline_conv/weighting.py
+3
-20
No files found.
aten/TH/THWeighting.h
View file @
04dc2518
...
@@ -9,6 +9,3 @@ void THDoubleTensor_weightingBackwardWeight(THDoubleTensor *self, THDoubleTensor
...
@@ -9,6 +9,3 @@ void THDoubleTensor_weightingBackwardWeight(THDoubleTensor *self, THDoubleTensor
void
THFloatTensor_weightingBackwardBasis
(
THFloatTensor
*
self
,
THFloatTensor
*
gradOutput
,
THFloatTensor
*
src
,
THFloatTensor
*
weight
,
THLongTensor
*
weightIndex
);
void
THFloatTensor_weightingBackwardBasis
(
THFloatTensor
*
self
,
THFloatTensor
*
gradOutput
,
THFloatTensor
*
src
,
THFloatTensor
*
weight
,
THLongTensor
*
weightIndex
);
void
THDoubleTensor_weightingBackwardBasis
(
THDoubleTensor
*
self
,
THDoubleTensor
*
gradOutput
,
THDoubleTensor
*
src
,
THDoubleTensor
*
weight
,
THLongTensor
*
weightIndex
);
void
THDoubleTensor_weightingBackwardBasis
(
THDoubleTensor
*
self
,
THDoubleTensor
*
gradOutput
,
THDoubleTensor
*
src
,
THDoubleTensor
*
weight
,
THLongTensor
*
weightIndex
);
void
THFloatTensor_weightingBackward
(
THFloatTensor
*
gradSrc
,
THFloatTensor
*
gradWeight
,
THFloatTensor
*
gradBasis
,
THFloatTensor
*
gradOutput
,
THFloatTensor
*
src
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weightIndex
);
void
THDoubleTensor_weightingBackward
(
THDoubleTensor
*
gradSrc
,
THDoubleTensor
*
gradWeight
,
THDoubleTensor
*
gradBasis
,
THDoubleTensor
*
gradOutput
,
THDoubleTensor
*
src
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weightIndex
);
aten/TH/generic/THWeighting.c
View file @
04dc2518
...
@@ -116,41 +116,4 @@ void THTensor_(weightingBackwardBasis)(THTensor *self, THTensor *gradOutput, THT
...
@@ -116,41 +116,4 @@ void THTensor_(weightingBackwardBasis)(THTensor *self, THTensor *gradOutput, THT
}
}
}
}
void
THTensor_
(
weightingBackward
)(
THTensor
*
gradSrc
,
THTensor
*
gradWeight
,
THTensor
*
gradBasis
,
THTensor
*
gradOutput
,
THTensor
*
src
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weightIndex
)
{
THTensor_
(
fill
)(
gradSrc
,
0
);
THTensor_
(
fill
)(
gradWeight
,
0
);
THTensor_
(
fill
)(
gradBasis
,
0
);
real
*
gradSrcData
=
THTensor_
(
data
)(
gradSrc
);
real
*
gradWeightData
=
THTensor_
(
data
)(
gradWeight
);
real
*
gradBasisData
=
THTensor_
(
data
)(
gradBasis
);
real
*
gradOutputData
=
THTensor_
(
data
)(
gradOutput
);
real
*
srcData
=
THTensor_
(
data
)(
src
);
real
*
weightData
=
THTensor_
(
data
)(
weight
);
real
*
basisData
=
THTensor_
(
data
)(
basis
);
int64_t
*
weightIndexData
=
THLongTensor_data
(
weightIndex
);
ptrdiff_t
e
,
mOut
,
s
,
mIn
;
real
g
,
b
,
w
,
f
;
int64_t
wi
;
for
(
e
=
0
;
e
<
THTensor_
(
size
)(
src
,
0
);
e
++
)
{
for
(
mOut
=
0
;
mOut
<
THTensor_
(
size
)(
gradOutput
,
1
);
mOut
++
)
{
g
=
gradOutputData
[
e
*
gradOutput
->
stride
[
0
]
+
mOut
*
gradOutput
->
stride
[
1
]];
for
(
s
=
0
;
s
<
THTensor_
(
size
)(
basis
,
1
);
s
++
)
{
b
=
basisData
[
e
*
basis
->
stride
[
0
]
+
s
*
basis
->
stride
[
1
]];
wi
=
weightIndexData
[
e
*
weightIndex
->
stride
[
0
]
+
s
*
weightIndex
->
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
THTensor_
(
size
)(
src
,
1
);
mIn
++
)
{
w
=
weightData
[
wi
*
weight
->
stride
[
0
]
+
mIn
*
weight
->
stride
[
1
]
+
mOut
*
weight
->
stride
[
2
]];
f
=
srcData
[
e
*
src
->
stride
[
0
]
+
mIn
*
src
->
stride
[
1
]];
gradSrcData
[
e
*
gradSrc
->
stride
[
0
]
+
mIn
*
gradSrc
->
stride
[
1
]]
+=
g
*
w
*
b
;
gradWeightData
[
wi
*
gradWeight
->
stride
[
0
]
+
mOut
*
gradWeight
->
stride
[
1
]
+
mIn
*
gradWeight
->
stride
[
2
]]
+=
f
*
g
*
b
;
gradBasisData
[
e
*
gradBasis
->
stride
[
0
]
+
s
*
gradBasis
->
stride
[
1
]]
+=
g
*
w
*
f
;
}
}
}
}
}
#endif // TH_GENERIC_FILE
#endif // TH_GENERIC_FILE
aten/THC/THCWeighting.cu
View file @
04dc2518
...
@@ -31,20 +31,22 @@ __global__ void weightingBackwardSrcKernel(TensorInfo<T> self, TensorInfo<T> gra
...
@@ -31,20 +31,22 @@ __global__ void weightingBackwardSrcKernel(TensorInfo<T> self, TensorInfo<T> gra
TensorInfo
<
T
>
weight
,
TensorInfo
<
T
>
basis
,
TensorInfo
<
T
>
weight
,
TensorInfo
<
T
>
basis
,
TensorInfo
<
int64_t
>
weightIndex
,
int
n
)
{
TensorInfo
<
int64_t
>
weightIndex
,
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
ptrdiff_t
e
=
i
/
self
.
size
[
1
],
m
In
=
i
%
self
.
size
[
1
],
s
,
m
Out
;
ptrdiff_t
e
=
i
/
gradOutput
.
size
[
1
],
m
Out
=
i
%
gradOutput
.
size
[
1
],
s
,
m
In
;
T
v
=
ScalarConvert
<
int
,
T
>::
to
(
0
)
,
b
,
tmp
;
T
v
,
b
,
tmp
;
int64_t
wi
;
int64_t
wi
;
for
(
s
=
0
;
s
<
basis
.
size
[
1
];
s
++
)
{
T
g
=
gradOutput
.
data
[
e
*
gradOutput
.
stride
[
0
]
+
mOut
*
gradOutput
.
stride
[
1
]];
b
=
basis
.
data
[
e
*
basis
.
stride
[
0
]
+
s
*
basis
.
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
self
.
size
[
1
];
mIn
++
)
{
wi
=
weightIndex
.
data
[
e
*
weightIndex
.
stride
[
0
]
+
s
*
weightIndex
.
stride
[
1
]];
v
=
ScalarConvert
<
int
,
T
>::
to
(
0
);
for
(
mOut
=
0
;
mOut
<
gradOutput
.
size
[
1
];
mOut
++
)
{
for
(
s
=
0
;
s
<
basis
.
size
[
1
];
s
++
)
{
tmp
=
weight
.
data
[
wi
*
weight
.
stride
[
0
]
+
mOut
*
weight
.
stride
[
1
]
+
mIn
*
weight
.
stride
[
2
]];
b
=
basis
.
data
[
e
*
basis
.
stride
[
0
]
+
s
*
basis
.
stride
[
1
]];
tmp
=
THCNumerics
<
T
>::
mul
(
tmp
,
gradOutput
.
data
[
e
*
gradOutput
.
stride
[
0
]
+
mOut
*
gradOutput
.
stride
[
1
]]);
wi
=
weightIndex
.
data
[
e
*
weightIndex
.
stride
[
0
]
+
s
*
weightIndex
.
stride
[
1
]];
tmp
=
weight
.
data
[
wi
*
weight
.
stride
[
0
]
+
mIn
*
weight
.
stride
[
1
]
+
mOut
*
weight
.
stride
[
2
]];
tmp
=
THCNumerics
<
T
>::
mul
(
tmp
,
b
);
tmp
=
THCNumerics
<
T
>::
mul
(
tmp
,
b
);
tmp
=
THCNumerics
<
T
>::
mul
(
tmp
,
g
);
v
=
THCNumerics
<
T
>::
add
(
v
,
tmp
);
v
=
THCNumerics
<
T
>::
add
(
v
,
tmp
);
}
}
atomicAdd
(
&
self
.
data
[
e
*
self
.
stride
[
0
]
+
mIn
*
self
.
stride
[
1
]],
v
);
}
}
self
.
data
[
e
*
self
.
stride
[
0
]
+
mIn
*
self
.
stride
[
1
]]
=
v
;
}
}
}
}
...
@@ -62,8 +64,7 @@ __global__ void weightingBackwardWeightKernel(TensorInfo<T> self, TensorInfo<T>
...
@@ -62,8 +64,7 @@ __global__ void weightingBackwardWeightKernel(TensorInfo<T> self, TensorInfo<T>
wi
=
weightIndex
.
data
[
e
*
weightIndex
.
stride
[
0
]
+
s
*
weightIndex
.
stride
[
1
]];
wi
=
weightIndex
.
data
[
e
*
weightIndex
.
stride
[
0
]
+
s
*
weightIndex
.
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
src
.
size
[
1
];
mIn
++
)
{
for
(
mIn
=
0
;
mIn
<
src
.
size
[
1
];
mIn
++
)
{
v
=
src
.
data
[
e
*
src
.
stride
[
0
]
+
mIn
*
src
.
stride
[
1
]];
v
=
src
.
data
[
e
*
src
.
stride
[
0
]
+
mIn
*
src
.
stride
[
1
]];
v
=
THCNumerics
<
T
>::
mul
(
v
,
b
);
v
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
v
,
b
),
g
);
v
=
THCNumerics
<
T
>::
mul
(
v
,
g
);
atomicAdd
(
&
self
.
data
[
wi
*
self
.
stride
[
0
]
+
mIn
*
self
.
stride
[
1
]
+
mOut
*
self
.
stride
[
2
]],
v
);
atomicAdd
(
&
self
.
data
[
wi
*
self
.
stride
[
0
]
+
mIn
*
self
.
stride
[
1
]
+
mOut
*
self
.
stride
[
2
]],
v
);
}
}
}
}
...
@@ -93,37 +94,5 @@ __global__ void weightingBackwardBasisKernel(TensorInfo<T> self, TensorInfo<T> g
...
@@ -93,37 +94,5 @@ __global__ void weightingBackwardBasisKernel(TensorInfo<T> self, TensorInfo<T> g
}
}
}
}
template
<
typename
T
>
__global__
void
weightingBackwardKernel
(
TensorInfo
<
T
>
gradSrc
,
TensorInfo
<
T
>
gradWeight
,
TensorInfo
<
T
>
gradBasis
,
TensorInfo
<
T
>
gradOutput
,
TensorInfo
<
T
>
src
,
TensorInfo
<
T
>
weight
,
TensorInfo
<
T
>
basis
,
TensorInfo
<
int64_t
>
weightIndex
,
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
ptrdiff_t
e
=
i
/
src
.
size
[
1
],
mIn
=
i
%
src
.
size
[
1
],
s
,
mOut
;
T
b
,
g
,
w
,
gs
=
ScalarConvert
<
int
,
T
>::
to
(
0
),
gw
,
gb
;
int64_t
wi
;
T
f
=
src
.
data
[
e
*
src
.
stride
[
0
]
+
mIn
*
src
.
stride
[
1
]];
for
(
s
=
0
;
s
<
basis
.
size
[
1
];
s
++
)
{
b
=
basis
.
data
[
e
*
basis
.
stride
[
0
]
+
s
*
basis
.
stride
[
1
]];
wi
=
weightIndex
.
data
[
e
*
weightIndex
.
stride
[
0
]
+
s
*
weightIndex
.
stride
[
1
]];
gb
=
ScalarConvert
<
int
,
T
>::
to
(
0
);
for
(
mOut
=
0
;
mOut
<
gradOutput
.
size
[
1
];
mOut
++
)
{
g
=
gradOutput
.
data
[
e
*
gradOutput
.
stride
[
0
]
+
mOut
*
gradOutput
.
stride
[
1
]];
w
=
weight
.
data
[
wi
*
weight
.
stride
[
0
]
+
mOut
*
weight
.
stride
[
1
]
+
mIn
*
weight
.
stride
[
2
]];
gs
=
THCNumerics
<
T
>::
add
(
gs
,
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
b
,
g
),
w
));
gw
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
f
,
b
),
g
);
atomicAdd
(
&
gradWeight
.
data
[
wi
*
gradWeight
.
stride
[
0
]
+
mOut
*
gradWeight
.
stride
[
1
]
+
mIn
*
gradWeight
.
stride
[
2
]],
gw
);
gb
=
THCNumerics
<
T
>::
add
(
gb
,
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
g
,
f
),
w
));
}
atomicAdd
(
&
gradBasis
.
data
[
e
*
gradBasis
.
stride
[
0
]
+
s
*
gradBasis
.
stride
[
1
]],
gb
);
}
gradSrc
.
data
[
e
*
gradSrc
.
stride
[
0
]
+
mIn
*
gradSrc
.
stride
[
1
]]
=
gs
;
}
}
#include "generic/THCWeighting.cu"
#include "generic/THCWeighting.cu"
#include "THC/THCGenerateFloatTypes.h"
#include "THC/THCGenerateFloatTypes.h"
aten/THC/common.cuh
View file @
04dc2518
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
for (ptrdiff_t I = blockIdx.x * blockDim.x + threadIdx.x; I < N; I += blockDim.x * gridDim.x)
for (ptrdiff_t I = blockIdx.x * blockDim.x + threadIdx.x; I < N; I += blockDim.x * gridDim.x)
const
int
MAX_DIMS
=
25
;
const
int
MAX_DIMS
=
25
;
const
int
NUM_THREADS
=
512
;
const
int
NUM_THREADS
=
1024
;
inline
int
GET_BLOCKS
(
int
N
)
{
inline
int
GET_BLOCKS
(
int
N
)
{
return
(
N
+
NUM_THREADS
-
1
)
/
NUM_THREADS
;
return
(
N
+
NUM_THREADS
-
1
)
/
NUM_THREADS
;
...
...
aten/THC/generic/THCWeighting.cu
View file @
04dc2518
...
@@ -22,7 +22,7 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso
...
@@ -22,7 +22,7 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso
THCudaLongTensor
*
weightIndex
)
{
THCudaLongTensor
*
weightIndex
)
{
THCAssertSameGPU
(
THCTensor_
(
checkGPU
)(
state
,
5
,
self
,
gradOutput
,
weight
,
basis
,
weightIndex
));
THCAssertSameGPU
(
THCTensor_
(
checkGPU
)(
state
,
5
,
self
,
gradOutput
,
weight
,
basis
,
weightIndex
));
weight
=
THCTensor_
(
newTranspose
)(
state
,
weight
,
1
,
2
);
THCTensor_
(
fill
)(
state
,
self
,
ScalarConvert
<
int
,
real
>::
to
(
0
)
);
TensorInfo
<
real
>
selfInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
self
);
TensorInfo
<
real
>
selfInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
self
);
TensorInfo
<
real
>
gradOutputInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
gradOutput
);
TensorInfo
<
real
>
gradOutputInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
gradOutput
);
...
@@ -30,10 +30,8 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso
...
@@ -30,10 +30,8 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso
TensorInfo
<
real
>
basisInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
basis
);
TensorInfo
<
real
>
basisInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
basis
);
TensorInfo
<
int64_t
>
weightIndexInfo
=
THCudaLongTensor_getTensorInfo
(
state
,
weightIndex
);
TensorInfo
<
int64_t
>
weightIndexInfo
=
THCudaLongTensor_getTensorInfo
(
state
,
weightIndex
);
KERNEL_REAL_RUN
(
weightingBackwardSrcKernel
,
THCTensor_
(
nElement
)(
state
,
self
),
selfInfo
,
KERNEL_REAL_RUN
(
weightingBackwardSrcKernel
,
THCTensor_
(
nElement
)(
state
,
gradOutput
),
selfInfo
,
gradOutputInfo
,
weightInfo
,
basisInfo
,
weightIndexInfo
);
gradOutputInfo
,
weightInfo
,
basisInfo
,
weightIndexInfo
);
THCTensor_
(
free
)(
state
,
weight
);
}
}
void
THCTensor_
(
weightingBackwardWeight
)(
THCState
*
state
,
THCTensor
*
self
,
THCTensor
*
gradOutput
,
void
THCTensor_
(
weightingBackwardWeight
)(
THCState
*
state
,
THCTensor
*
self
,
THCTensor
*
gradOutput
,
...
@@ -70,32 +68,4 @@ void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTen
...
@@ -70,32 +68,4 @@ void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTen
gradOutputInfo
,
srcInfo
,
weightInfo
,
weightIndexInfo
);
gradOutputInfo
,
srcInfo
,
weightInfo
,
weightIndexInfo
);
}
}
void
THCTensor_
(
weightingBackward
)(
THCState
*
state
,
THCTensor
*
gradSrc
,
THCTensor
*
gradWeight
,
THCTensor
*
gradBasis
,
THCTensor
*
gradOutput
,
THCTensor
*
src
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
)
{
THCAssertSameGPU
(
THCTensor_
(
checkGPU
)(
state
,
8
,
gradSrc
,
gradWeight
,
gradBasis
,
src
,
weight
,
basis
,
weightIndex
));
THCTensor_
(
fill
)(
state
,
gradWeight
,
ScalarConvert
<
int
,
real
>::
to
(
0
));
THCTensor_
(
fill
)(
state
,
gradBasis
,
ScalarConvert
<
int
,
real
>::
to
(
0
));
weight
=
THCTensor_
(
newTranspose
)(
state
,
weight
,
1
,
2
);
TensorInfo
<
real
>
gradSrcInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
gradSrc
);
TensorInfo
<
real
>
gradWeightInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
gradWeight
);
TensorInfo
<
real
>
gradBasisInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
gradBasis
);
TensorInfo
<
real
>
gradOutputInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
gradOutput
);
TensorInfo
<
real
>
srcInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
src
);
TensorInfo
<
real
>
weightInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
weight
);
TensorInfo
<
real
>
basisInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
basis
);
TensorInfo
<
int64_t
>
weightIndexInfo
=
THCudaLongTensor_getTensorInfo
(
state
,
weightIndex
);
KERNEL_REAL_RUN
(
weightingBackwardKernel
,
THCTensor_
(
nElement
)(
state
,
src
),
gradSrcInfo
,
gradWeightInfo
,
gradBasisInfo
,
gradOutputInfo
,
srcInfo
,
weightInfo
,
basisInfo
,
weightIndexInfo
);
THCTensor_
(
free
)(
state
,
weight
);
}
#endif // THC_GENERIC_FILE
#endif // THC_GENERIC_FILE
aten/THC/generic/THCWeighting.h
View file @
04dc2518
...
@@ -18,9 +18,4 @@ void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTen
...
@@ -18,9 +18,4 @@ void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTen
THCTensor
*
src
,
THCTensor
*
weight
,
THCTensor
*
src
,
THCTensor
*
weight
,
THCudaLongTensor
*
weightIndex
);
THCudaLongTensor
*
weightIndex
);
void
THCTensor_
(
weightingBackward
)(
THCState
*
state
,
THCTensor
*
gradSrc
,
THCTensor
*
gradWeight
,
THCTensor
*
gradBasis
,
THCTensor
*
gradOutput
,
THCTensor
*
src
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
);
#endif // THC_GENERIC_FILE
#endif // THC_GENERIC_FILE
aten/THCC/THCCWeighting.h
View file @
04dc2518
...
@@ -9,6 +9,3 @@ void THCCDoubleTensor_weightingBackwardWeight(THCudaDoubleTensor *self, THCudaDo
...
@@ -9,6 +9,3 @@ void THCCDoubleTensor_weightingBackwardWeight(THCudaDoubleTensor *self, THCudaDo
void
THCCFloatTensor_weightingBackwardBasis
(
THCudaTensor
*
self
,
THCudaTensor
*
gradOutput
,
THCudaTensor
*
src
,
THCudaTensor
*
weight
,
THCudaLongTensor
*
weightIndex
);
void
THCCFloatTensor_weightingBackwardBasis
(
THCudaTensor
*
self
,
THCudaTensor
*
gradOutput
,
THCudaTensor
*
src
,
THCudaTensor
*
weight
,
THCudaLongTensor
*
weightIndex
);
void
THCCDoubleTensor_weightingBackwardBasis
(
THCudaDoubleTensor
*
self
,
THCudaDoubleTensor
*
gradOutput
,
THCudaDoubleTensor
*
src
,
THCudaDoubleTensor
*
weight
,
THCudaLongTensor
*
weightIndex
);
void
THCCDoubleTensor_weightingBackwardBasis
(
THCudaDoubleTensor
*
self
,
THCudaDoubleTensor
*
gradOutput
,
THCudaDoubleTensor
*
src
,
THCudaDoubleTensor
*
weight
,
THCudaLongTensor
*
weightIndex
);
void
THCCFloatTensor_weightingBackward
(
THCudaTensor
*
gradSrc
,
THCudaTensor
*
gradWeight
,
THCudaTensor
*
gradBasis
,
THCudaTensor
*
gradOutput
,
THCudaTensor
*
src
,
THCudaTensor
*
weight
,
THCudaTensor
*
basis
,
THCudaLongTensor
*
weightIndex
);
void
THCCDoubleTensor_weightingBackward
(
THCudaDoubleTensor
*
gradSrc
,
THCudaDoubleTensor
*
gradWeight
,
THCudaDoubleTensor
*
gradBasis
,
THCudaDoubleTensor
*
gradOutput
,
THCudaDoubleTensor
*
src
,
THCudaDoubleTensor
*
weight
,
THCudaDoubleTensor
*
basis
,
THCudaLongTensor
*
weightIndex
);
aten/THCC/generic/THCCWeighting.c
View file @
04dc2518
...
@@ -22,12 +22,4 @@ void THCCTensor_(weightingBackwardBasis)(THCTensor *self, THCTensor *gradOutput,
...
@@ -22,12 +22,4 @@ void THCCTensor_(weightingBackwardBasis)(THCTensor *self, THCTensor *gradOutput,
THCTensor_
(
weightingBackwardBasis
)(
state
,
self
,
gradOutput
,
src
,
weight
,
weightIndex
);
THCTensor_
(
weightingBackwardBasis
)(
state
,
self
,
gradOutput
,
src
,
weight
,
weightIndex
);
}
}
void
THCCTensor_
(
weightingBackward
)(
THCTensor
*
gradSrc
,
THCTensor
*
gradWeight
,
THCTensor
*
gradBasis
,
THCTensor
*
gradOutput
,
THCTensor
*
src
,
THCTensor
*
weight
,
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
)
{
THCTensor_
(
weightingBackward
)(
state
,
gradSrc
,
gradWeight
,
gradBasis
,
gradOutput
,
src
,
weight
,
basis
,
weightIndex
);
}
#endif // THC_GENERIC_FILE
#endif // THC_GENERIC_FILE
test/test_basis.py
View file @
04dc2518
...
@@ -76,4 +76,4 @@ def test_spline_basis_backward_gpu(degree):
...
@@ -76,4 +76,4 @@ def test_spline_basis_backward_gpu(degree):
pseudo
=
Variable
(
pseudo
,
requires_grad
=
True
)
pseudo
=
Variable
(
pseudo
,
requires_grad
=
True
)
op
=
SplineBasis
(
degree
,
kernel_size
,
is_open_spline
)
op
=
SplineBasis
(
degree
,
kernel_size
,
is_open_spline
)
#
assert gradcheck(op, (pseudo, ), eps=1e-6, atol=1e-4) is True
assert
gradcheck
(
op
,
(
pseudo
,
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
test/test_weighting.py
View file @
04dc2518
...
@@ -73,8 +73,8 @@ def test_spline_basis_backward_gpu():
...
@@ -73,8 +73,8 @@ def test_spline_basis_backward_gpu():
basis
,
weight_index
=
spline_basis
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
basis
,
weight_index
=
spline_basis
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
src
=
Variable
(
src
,
requires_grad
=
True
)
src
=
Variable
(
src
,
requires_grad
=
True
)
weight
=
Variable
(
weight
,
requires_grad
=
Tru
e
)
weight
=
Variable
(
weight
,
requires_grad
=
Fals
e
)
basis
=
Variable
(
basis
,
requires_grad
=
Tru
e
)
basis
=
Variable
(
basis
,
requires_grad
=
Fals
e
)
op
=
SplineWeighting
(
weight_index
)
op
=
SplineWeighting
(
weight_index
)
assert
gradcheck
(
op
,
(
src
,
weight
,
basis
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
assert
gradcheck
(
op
,
(
src
,
weight
,
basis
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
torch_spline_conv/utils/ffi.py
View file @
04dc2518
...
@@ -48,10 +48,3 @@ def weighting_backward_weight(self, grad_output, src, basis, weight_index):
...
@@ -48,10 +48,3 @@ def weighting_backward_weight(self, grad_output, src, basis, weight_index):
def
weighting_backward_basis
(
self
,
grad_output
,
src
,
weight
,
weight_index
):
def
weighting_backward_basis
(
self
,
grad_output
,
src
,
weight
,
weight_index
):
func
=
get_func
(
'weightingBackwardBasis'
,
self
.
is_cuda
,
self
)
func
=
get_func
(
'weightingBackwardBasis'
,
self
.
is_cuda
,
self
)
func
(
self
,
grad_output
,
src
,
weight
,
weight_index
)
func
(
self
,
grad_output
,
src
,
weight
,
weight_index
)
def
weighting_backward
(
grad_src
,
grad_weight
,
grad_basis
,
grad_output
,
src
,
weight
,
basis
,
weight_index
):
func
=
get_func
(
'weightingBackward'
,
grad_src
.
is_cuda
,
grad_src
)
func
(
grad_src
,
grad_weight
,
grad_basis
,
grad_output
,
src
,
weight
,
basis
,
weight_index
)
torch_spline_conv/weighting.py
View file @
04dc2518
...
@@ -5,7 +5,6 @@ from .utils.ffi import weighting_forward as weighting_fw
...
@@ -5,7 +5,6 @@ from .utils.ffi import weighting_forward as weighting_fw
from
.utils.ffi
import
weighting_backward_src
as
weighting_bw_src
from
.utils.ffi
import
weighting_backward_src
as
weighting_bw_src
from
.utils.ffi
import
weighting_backward_weight
as
weighting_bw_weight
from
.utils.ffi
import
weighting_backward_weight
as
weighting_bw_weight
from
.utils.ffi
import
weighting_backward_basis
as
weighting_bw_basis
from
.utils.ffi
import
weighting_backward_basis
as
weighting_bw_basis
from
.utils.ffi
import
weighting_backward
as
weighting_bw
def
weighting_forward
(
src
,
weight
,
basis
,
weight_index
):
def
weighting_forward
(
src
,
weight
,
basis
,
weight_index
):
...
@@ -32,16 +31,6 @@ def weighting_backward_basis(grad_output, src, weight, weight_index):
...
@@ -32,16 +31,6 @@ def weighting_backward_basis(grad_output, src, weight, weight_index):
return
grad_basis
return
grad_basis
def
weighting_backward
(
grad_output
,
src
,
weight
,
basis
,
weight_index
):
grad_src
=
src
.
new
(
src
.
size
())
# grad_weight = weight.new(weight.size())
grad_weight
=
weight
.
new
(
weight
.
size
(
0
),
weight
.
size
(
2
),
weight
.
size
(
1
))
grad_basis
=
basis
.
new
(
basis
.
size
())
weighting_bw
(
grad_src
,
grad_weight
,
grad_basis
,
grad_output
,
src
,
weight
,
basis
,
weight_index
)
return
grad_src
,
grad_weight
.
transpose
(
1
,
2
),
grad_basis
class
SplineWeighting
(
Function
):
class
SplineWeighting
(
Function
):
def
__init__
(
self
,
weight_index
):
def
__init__
(
self
,
weight_index
):
super
(
SplineWeighting
,
self
).
__init__
()
super
(
SplineWeighting
,
self
).
__init__
()
...
@@ -55,20 +44,14 @@ class SplineWeighting(Function):
...
@@ -55,20 +44,14 @@ class SplineWeighting(Function):
grad_src
=
grad_weight
=
grad_basis
=
None
grad_src
=
grad_weight
=
grad_basis
=
None
src
,
weight
,
basis
=
self
.
saved_tensors
src
,
weight
,
basis
=
self
.
saved_tensors
needs_src
,
needs_weight
,
needs_basis
=
self
.
needs_input_grad
if
self
.
needs_input_grad
[
0
]:
if
needs_src
and
needs_weight
and
needs_basis
:
return
weighting_backward
(
grad_output
,
src
,
weight
,
basis
,
self
.
weight_index
)
if
needs_src
:
grad_src
=
weighting_backward_src
(
grad_output
,
weight
,
basis
,
grad_src
=
weighting_backward_src
(
grad_output
,
weight
,
basis
,
self
.
weight_index
)
self
.
weight_index
)
if
needs_
weight
:
if
self
.
needs_
input_grad
[
1
]
:
K
=
weight
.
size
(
0
)
K
=
weight
.
size
(
0
)
grad_weight
=
weighting_backward_weight
(
grad_output
,
src
,
basis
,
grad_weight
=
weighting_backward_weight
(
grad_output
,
src
,
basis
,
self
.
weight_index
,
K
)
self
.
weight_index
,
K
)
if
needs_
basis
:
if
self
.
needs_
input_grad
[
2
]
:
grad_basis
=
weighting_backward_basis
(
grad_output
,
src
,
weight
,
grad_basis
=
weighting_backward_basis
(
grad_output
,
src
,
weight
,
self
.
weight_index
)
self
.
weight_index
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment