Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
60ab8eea
"vscode:/vscode.git/clone" did not exist on "96ad7f0c80538e06323ab2a4d2d4fb513022f59a"
Commit
60ab8eea
authored
Apr 11, 2018
by
rusty1s
Browse files
backward gpu / weight
parent
9362b2d3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
133 additions
and
2 deletions
+133
-2
aten/THC/THCAtomics.cuh
aten/THC/THCAtomics.cuh
+43
-0
aten/THC/THCWeighting.cu
aten/THC/THCWeighting.cu
+47
-1
aten/THC/generic/THCWeighting.cu
aten/THC/generic/THCWeighting.cu
+26
-1
test/test_weighting.py
test/test_weighting.py
+17
-0
No files found.
aten/THC/THCAtomics.cuh
0 → 100644
View file @
60ab8eea
#ifndef THC_ATOMICS_INC
#define THC_ATOMICS_INC
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
// Software emulation of double-precision atomicAdd for device builds that
// lack the native instruction (hardware double atomicAdd arrived with sm_60
// and CUDA 8). Standard compare-and-swap retry loop: reinterpret the address
// as a 64-bit integer word, compute the sum through bit reinterpretation, and
// retry until no other thread modified the word in between.
static inline __device__ void atomicAdd(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *) address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    // atomicCAS returns the word's previous value; loop exits when the swap
    // succeeded (i.e. nobody changed *address between read and CAS).
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
}
#elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000)
// Host-compilation stub so this header parses on CUDA < 8; never executed.
static inline __device__ void atomicAdd(double *address, double val) {}
#endif
#ifdef CUDA_HALF_TENSOR
static
inline
__device__
void
atomicAdd
(
half
*
address
,
half
val
)
{
unsigned
int
*
address_as_ui
=
(
unsigned
int
*
)
((
char
*
)
address
-
((
size_t
)
address
&
2
));
unsigned
int
old
=
*
address_as_ui
;
unsigned
int
assumed
;
do
{
assumed
=
old
;
#if CUDA_VERSION < 9000
half
hsum
;
hsum
.
x
=
(
size_t
)
address
&
2
?
(
old
>>
16
)
:
(
old
&
0xffff
);
hsum
=
THCNumerics
<
half
>::
add
(
hsum
,
val
);
#else // CUDA_VERSION < 9000
__half_raw
hsum
;
hsum
.
x
=
(
size_t
)
address
&
2
?
(
old
>>
16
)
:
(
old
&
0xffff
);
half
tmpres
=
THCNumerics
<
half
>::
add
(
hsum
,
val
);
hsum
=
__half_raw
(
tmpres
);
#endif // CUDA_VERSION
old
=
(
size_t
)
address
&
2
?
(
old
&
0xffff
)
|
(
hsum
.
x
<<
16
)
:
(
old
&
0xffff0000
)
|
hsum
.
x
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
old
);
}
while
(
assumed
!=
old
);
}
#endif // CUDA_HALF_TENSOR
#endif // THC_ATOMICS_INC
aten/THC/THCWeighting.cu
View file @
60ab8eea
...
...
@@ -2,6 +2,7 @@
#include "common.cuh"
#include "THCNumerics.cuh"
#include "THCAtomics.cuh"
template
<
typename
T
>
__global__
void
weightingForwardKernel
(
TensorInfo
<
T
>
self
,
TensorInfo
<
T
>
src
,
TensorInfo
<
T
>
weight
,
...
...
@@ -16,8 +17,8 @@ __global__ void weightingForwardKernel(TensorInfo<T> self, TensorInfo<T> src, Te
wi
=
weightIndex
.
data
[
e
*
weightIndex
.
stride
[
0
]
+
s
*
weightIndex
.
stride
[
1
]];
for
(
mIn
=
0
;
mIn
<
src
.
size
[
1
];
mIn
++
)
{
tmp
=
weight
.
data
[
wi
*
weight
.
stride
[
0
]
+
mIn
*
weight
.
stride
[
1
]
+
mOut
*
weight
.
stride
[
2
]];
tmp
=
THCNumerics
<
T
>::
mul
(
tmp
,
b
);
tmp
=
THCNumerics
<
T
>::
mul
(
tmp
,
src
.
data
[
e
*
src
.
stride
[
0
]
+
mIn
*
src
.
stride
[
1
]]);
tmp
=
THCNumerics
<
T
>::
mul
(
tmp
,
b
);
v
=
THCNumerics
<
T
>::
add
(
v
,
tmp
);
}
}
...
...
@@ -25,5 +26,50 @@ __global__ void weightingForwardKernel(TensorInfo<T> self, TensorInfo<T> src, Te
}
}
// Backward pass w.r.t. the source features. For each (edge e, input channel
// mIn) this accumulates, over every basis entry s and output channel mOut:
//   weight[wi, mOut, mIn] * gradOutput[e, mOut] * basis[e, s]
// where wi = weightIndex[e, s] selects the weight matrix for that basis entry.
// NOTE(review): the indexing [wi, mOut, mIn] implies `weight` arrives with
// its channel dims transposed — the host wrapper passes a transposed
// contiguous copy of the forward weight.
template <typename T>
__global__ void weightingBackwardSrcKernel(TensorInfo<T> self,
                                           TensorInfo<T> gradOutput,
                                           TensorInfo<T> weight,
                                           TensorInfo<T> basis,
                                           TensorInfo<int64_t> weightIndex,
                                           int n) {
  KERNEL_LOOP(i, n) {
    // Decompose the flat index into (edge, input channel).
    ptrdiff_t e = i / self.size[1], mIn = i % self.size[1], s, mOut;
    T v = ScalarConvert<int, T>::to(0), b, tmp;
    int64_t wi;
    for (s = 0; s < basis.size[1]; s++) {
      // Basis value and the weight-matrix index it selects for this entry.
      b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
      wi = weightIndex.data[e * weightIndex.stride[0] +
                            s * weightIndex.stride[1]];
      for (mOut = 0; mOut < gradOutput.size[1]; mOut++) {
        tmp = weight.data[wi * weight.stride[0] + mOut * weight.stride[1] +
                          mIn * weight.stride[2]];
        tmp = THCNumerics<T>::mul(tmp, gradOutput.data[e * gradOutput.stride[0] +
                                                       mOut * gradOutput.stride[1]]);
        tmp = THCNumerics<T>::mul(tmp, b);
        v = THCNumerics<T>::add(v, tmp);
      }
    }
    // Exactly one thread owns each (e, mIn) slot, so a plain store suffices.
    self.data[e * self.stride[0] + mIn * self.stride[1]] = v;
  }
}
// Backward pass w.r.t. the B-spline basis coefficients. For each basis entry
// (e, s) the gradient is
//   sum_{mIn} weight[wi, mIn, mOut] * src[e, mIn] * gradOutput[e, mOut]
// summed over all output channels mOut. Each thread handles one (e, mOut)
// pair, so threads with different mOut contribute to the same self[e, s]
// slot — accumulation must go through atomicAdd (see THCAtomics.cuh).
// The host wrapper zero-fills `self` before launch.
template <typename T>
__global__ void weightingBackwardBasisKernel(TensorInfo<T> self,
                                             TensorInfo<T> gradOutput,
                                             TensorInfo<T> src,
                                             TensorInfo<T> weight,
                                             TensorInfo<int64_t> weightIndex,
                                             int n) {
  KERNEL_LOOP(i, n) {
    // Decompose the flat index into (edge, output channel).
    ptrdiff_t e = i / gradOutput.size[1], mOut = i % gradOutput.size[1], s, mIn;
    T v, g, tmp;
    int64_t wi;
    // Incoming gradient for this (edge, output channel); reused across s.
    g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
    for (s = 0; s < weightIndex.size[1]; s++) {
      v = ScalarConvert<int, T>::to(0);
      wi = weightIndex.data[e * weightIndex.stride[0] +
                            s * weightIndex.stride[1]];
      for (mIn = 0; mIn < src.size[1]; mIn++) {
        tmp = weight.data[wi * weight.stride[0] + mIn * weight.stride[1] +
                          mOut * weight.stride[2]];
        tmp = THCNumerics<T>::mul(tmp, src.data[e * src.stride[0] +
                                                mIn * src.stride[1]]);
        tmp = THCNumerics<T>::mul(tmp, g);
        v = THCNumerics<T>::add(v, tmp);
      }
      // Different mOut threads target the same (e, s) entry → atomic.
      atomicAdd(&self.data[e * self.stride[0] + s * self.stride[1]], v);
    }
  }
}
#include "generic/THCWeighting.cu"
#include "THC/THCGenerateFloatTypes.h"
aten/THC/generic/THCWeighting.cu
View file @
60ab8eea
...
...
@@ -20,6 +20,20 @@ void THCTensor_(weightingForward)(THCState *state, THCTensor *self, THCTensor *s
// Launches the gradient-w.r.t.-source kernel.
// `self` receives the gradient; `gradOutput`, `weight`, `basis` and
// `weightIndex` are the saved forward inputs. The kernel indexes the weight
// as [wi, mOut, mIn], so a contiguous copy of the (1,2)-transposed weight is
// built first.
// Fix: the original leaked both temporaries — `tweight` (the transposed
// view) and the `newContiguous` copy were never released.
void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self,
                                      THCTensor *gradOutput, THCTensor *weight,
                                      THCTensor *basis,
                                      THCudaLongTensor *weightIndex) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, weight,
                                        basis, weightIndex));

  THCTensor *tweight = THCTensor_(new)(state);
  THCTensor_(transpose)(state, tweight, weight, 1, 2);
  weight = THCTensor_(newContiguous)(state, tweight);
  // The transposed view is no longer needed once the contiguous copy exists.
  THCTensor_(free)(state, tweight);

  TensorInfo<real> selfInfo = THCTensor_(getTensorInfo)(state, self);
  TensorInfo<real> gradOutputInfo = THCTensor_(getTensorInfo)(state, gradOutput);
  TensorInfo<real> weightInfo = THCTensor_(getTensorInfo)(state, weight);
  TensorInfo<real> basisInfo = THCTensor_(getTensorInfo)(state, basis);
  TensorInfo<int64_t> weightIndexInfo =
      THCudaLongTensor_getTensorInfo(state, weightIndex);

  // One thread per element of the gradient-input tensor.
  KERNEL_REAL_RUN(weightingBackwardSrcKernel,
                  THCTensor_(nElement)(state, self), selfInfo, gradOutputInfo,
                  weightInfo, basisInfo, weightIndexInfo);

  // Release the contiguous transposed copy (was leaked in the original).
  THCTensor_(free)(state, weight);
}
void
THCTensor_
(
weightingBackwardWeight
)(
THCState
*
state
,
THCTensor
*
self
,
THCTensor
*
gradOutput
,
...
...
@@ -30,7 +44,18 @@ void THCTensor_(weightingBackwardWeight)(THCState *state, THCTensor *self, THCTe
// Launches the gradient-w.r.t.-basis kernel.
// `self` receives the gradient; `gradOutput`, `src`, `weight` and
// `weightIndex` are the saved forward inputs.
void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self,
                                        THCTensor *gradOutput, THCTensor *src,
                                        THCTensor *weight,
                                        THCudaLongTensor *weightIndex) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, src,
                                        weight, weightIndex));

  // The kernel accumulates via atomicAdd, so the output must start at zero.
  THCTensor_(fill)(state, self, ScalarConvert<int, real>::to(0));

  TensorInfo<int64_t> infoWeightIndex =
      THCudaLongTensor_getTensorInfo(state, weightIndex);
  TensorInfo<real> infoWeight = THCTensor_(getTensorInfo)(state, weight);
  TensorInfo<real> infoSrc = THCTensor_(getTensorInfo)(state, src);
  TensorInfo<real> infoGradOutput = THCTensor_(getTensorInfo)(state, gradOutput);
  TensorInfo<real> infoSelf = THCTensor_(getTensorInfo)(state, self);

  // One thread per gradOutput element drives the reduction.
  KERNEL_REAL_RUN(weightingBackwardBasisKernel,
                  THCTensor_(nElement)(state, gradOutput), infoSelf,
                  infoGradOutput, infoSrc, infoWeight, infoWeightIndex);
}
#endif // THC_GENERIC_FILE
test/test_weighting.py
View file @
60ab8eea
...
...
@@ -61,3 +61,20 @@ def test_spline_basis_backward_cpu():
op
=
SplineWeighting
(
weight_index
)
assert
gradcheck
(
op
,
(
src
,
weight
,
basis
),
eps
=
1e-6
,
atol
=
1e-4
)
is
True
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_spline_basis_backward_gpu():
    """Gradcheck of SplineWeighting's basis gradient on the GPU."""
    # Double precision keeps gradcheck's finite differences accurate.
    src = torch.cuda.DoubleTensor(4, 2).uniform_(0, 1)
    weight = torch.cuda.DoubleTensor(25, 2, 4).uniform_(0, 1)
    kernel_size = torch.cuda.LongTensor([5, 5])
    is_open_spline = torch.cuda.ByteTensor([1, 1])
    pseudo = torch.cuda.DoubleTensor(4, 2).uniform_(0, 1)
    basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline)

    # Only the basis tensor participates in this gradient check.
    basis = Variable(basis, requires_grad=True)
    src = Variable(src, requires_grad=False)
    weight = Variable(weight, requires_grad=False)

    operation = SplineWeighting(weight_index)
    assert gradcheck(operation, (src, weight, basis), eps=1e-6, atol=1e-4) is True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment