Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
865ef24d
Commit
865ef24d
authored
Apr 11, 2018
by
rusty1s
Browse files
backward gpu complete
parent
60ab8eea
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
5 deletions
+39
-5
aten/THC/THCWeighting.cu
aten/THC/THCWeighting.cu
+24
-2
aten/THC/generic/THCWeighting.cu
aten/THC/generic/THCWeighting.cu
+12
-0
test/test_weighting.py
test/test_weighting.py
+3
-3
No files found.
aten/THC/THCWeighting.cu
View file @
865ef24d
...
...
@@ -48,15 +48,37 @@ __global__ void weightingBackwardSrcKernel(TensorInfo<T> self, TensorInfo<T> gra
}
}
template
<
typename
T
>
__global__
void
weightingBackwardWeightKernel
(
TensorInfo
<
T
>
self
,
TensorInfo
<
T
>
gradOutput
,
TensorInfo
<
T
>
src
,
TensorInfo
<
T
>
basis
,
TensorInfo
<
int64_t
>
weightIndex
,
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
ptrdiff_t
e
=
i
/
gradOutput
.
size
[
1
],
mOut
=
i
%
gradOutput
.
size
[
1
],
s
,
mIn
;
T
b
,
v
;
int64_t
wi
;
T
g
=
gradOutput
.
data
[
e
*
gradOutput
.
stride
[
0
]
+
mOut
*
gradOutput
.
stride
[
1
]];
for
(
s
=
0
;
s
<
weightIndex
.
size
[
1
];
s
++
)
{
b
=
basis
.
data
[
e
*
basis
.
stride
[
0
]
+
s
*
basis
.
stride
[
1
]];
wi
=
weightIndex
.
data
[
e
*
weightIndex
.
stride
[
0
]
+
s
*
weightIndex
.
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
src
.
size
[
1
];
mIn
++
)
{
v
=
src
.
data
[
e
*
src
.
stride
[
0
]
+
mIn
*
src
.
stride
[
1
]];
v
=
THCNumerics
<
T
>::
mul
(
v
,
b
);
v
=
THCNumerics
<
T
>::
mul
(
v
,
g
);
atomicAdd
(
&
self
.
data
[
wi
*
self
.
stride
[
0
]
+
mIn
*
self
.
stride
[
1
]
+
mOut
*
self
.
stride
[
2
]],
v
);
}
}
}
}
template
<
typename
T
>
__global__
void
weightingBackwardBasisKernel
(
TensorInfo
<
T
>
self
,
TensorInfo
<
T
>
gradOutput
,
TensorInfo
<
T
>
src
,
TensorInfo
<
T
>
weight
,
TensorInfo
<
int64_t
>
weightIndex
,
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
ptrdiff_t
e
=
i
/
gradOutput
.
size
[
1
],
mOut
=
i
%
gradOutput
.
size
[
1
],
s
,
mIn
;
T
v
,
g
,
tmp
;
T
v
,
tmp
;
int64_t
wi
;
g
=
gradOutput
.
data
[
e
*
gradOutput
.
stride
[
0
]
+
mOut
*
gradOutput
.
stride
[
1
]];
T
g
=
gradOutput
.
data
[
e
*
gradOutput
.
stride
[
0
]
+
mOut
*
gradOutput
.
stride
[
1
]];
for
(
s
=
0
;
s
<
weightIndex
.
size
[
1
];
s
++
)
{
v
=
ScalarConvert
<
int
,
T
>::
to
(
0
);
wi
=
weightIndex
.
data
[
e
*
weightIndex
.
stride
[
0
]
+
s
*
weightIndex
.
stride
[
1
]];
...
...
aten/THC/generic/THCWeighting.cu
View file @
865ef24d
...
...
@@ -39,6 +39,18 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso
void
THCTensor_
(
weightingBackwardWeight
)(
THCState
*
state
,
THCTensor
*
self
,
THCTensor
*
gradOutput
,
THCTensor
*
src
,
THCTensor
*
basis
,
THCudaLongTensor
*
weightIndex
)
{
THCAssertSameGPU
(
THCTensor_
(
checkGPU
)(
state
,
5
,
self
,
gradOutput
,
src
,
basis
,
weightIndex
));
THCTensor_
(
fill
)(
state
,
self
,
ScalarConvert
<
int
,
real
>::
to
(
0
));
TensorInfo
<
real
>
selfInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
self
);
TensorInfo
<
real
>
gradOutputInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
gradOutput
);
TensorInfo
<
real
>
srcInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
src
);
TensorInfo
<
real
>
basisInfo
=
THCTensor_
(
getTensorInfo
)(
state
,
basis
);
TensorInfo
<
int64_t
>
weightIndexInfo
=
THCudaLongTensor_getTensorInfo
(
state
,
weightIndex
);
KERNEL_REAL_RUN
(
weightingBackwardWeightKernel
,
THCTensor_
(
nElement
)(
state
,
gradOutput
),
selfInfo
,
gradOutputInfo
,
srcInfo
,
basisInfo
,
weightIndexInfo
);
}
void
THCTensor_
(
weightingBackwardBasis
)(
THCState
*
state
,
THCTensor
*
self
,
THCTensor
*
gradOutput
,
...
...
test/test_weighting.py
View file @
865ef24d
...
...
@@ -72,9 +72,9 @@ def test_spline_basis_backward_gpu():
pseudo
=
torch
.
cuda
.
DoubleTensor
(
4
,
2
).
uniform_
(
0
,
1
)
basis
,
weight_index
=
spline_basis
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
src
=
Variable
(
src
,
requires_grad
=
Fals
e
)
weight
=
Variable
(
weight
,
requires_grad
=
Fals
e
)
basis
=
Variable
(
basis
,
requires_grad
=
Tru
e
)
src
=
Variable
(
src
,
requires_grad
=
Tru
e
)
weight
=
Variable
(
weight
,
requires_grad
=
Tru
e
)
basis
=
Variable
(
basis
,
requires_grad
=
Fals
e
)
op
=
SplineWeighting
(
weight_index
)
assert
gradcheck
(
op
,
(
src
,
weight
,
basis
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment