OpenDAS / Pytorch-Encoding / Commits

Commit 1235f3b0 (unverified), authored Aug 08, 2020 by Hang Zhang, committed via GitHub on Aug 08, 2020

Support PyTorch 1.6.0 (#309)

* doc
* pre-compile
* fix dispatch
Parent: f46bcf7f
Showing 18 changed files with 145 additions and 138 deletions
docs/source/notes/compile.rst          +5   −0
encoding/functions/customize.py        +6   −3
encoding/functions/dist_syncbn.py      +11  −8
encoding/functions/encoding.py         +12  −9
encoding/functions/rectify.py          +7   −5
encoding/functions/syncbn.py           +17  −14
encoding/lib/__init__.py               +0   −29
encoding/lib/cpu/rectify_cpu.cpp       +1   −1
encoding/lib/cpu/roi_align_cpu.cpp     +1   −1
encoding/lib/cpu/setup.py              +0   −17
encoding/lib/gpu/activation_kernel.cu  +15  −15
encoding/lib/gpu/encoding_kernel.cu    +4   −4
encoding/lib/gpu/rectify_cuda.cu       +1   −1
encoding/lib/gpu/roi_align_kernel.cu   +2   −2
encoding/lib/gpu/setup.py              +0   −19
encoding/lib/gpu/syncbn_kernel.cu      +7   −7
experiments/recognition/README.md      +1   −3
setup.py                               +55  −0
docs/source/notes/compile.rst

@@ -10,11 +10,16 @@ Installation
 * PIP Install::

     pip install torch-encoding --pre
+    # macOS
+    CC=clang CXX=clang++ pip install torch-encoding --pre

 * Install from source::

     git clone https://github.com/zhanghang1989/PyTorch-Encoding && cd PyTorch-Encoding
+    # ubuntu
     python setup.py install
+    # macOS
+    CC=clang CXX=clang++ python setup.py install

 Using Docker
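A quick way to confirm that an install via either route picked up the right build (a hedged sketch, not part of this commit; it only assumes the commands above completed and imports the installed package):

    # post-install sanity check (illustrative only)
    import torch
    import encoding  # the torch-encoding package installed above

    print(torch.__version__)   # this commit targets PyTorch 1.6.0
    print(encoding.__file__)   # shows which installation Python picked up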
encoding/functions/customize.py

@@ -11,7 +11,10 @@
 import torch
 from torch.autograd import Variable, Function
-from .. import lib
+from encoding import cpu
+if torch.cuda.device_count() > 0:
+    from encoding import gpu

 __all__ = ['NonMaxSuppression']

@@ -49,6 +52,6 @@ def NonMaxSuppression(boxes, scores, threshold):
     >>> surviving_box_indices = indices[mask]
     """
     if boxes.is_cuda:
-        return lib.gpu.non_max_suppression(boxes, scores, threshold)
+        return gpu.non_max_suppression(boxes, scores, threshold)
     else:
-        return lib.cpu.non_max_suppression(boxes, scores, threshold)
+        return cpu.non_max_suppression(boxes, scores, threshold)
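The same edit recurs in every Python wrapper below: the JIT-compiled lib module is replaced by the pre-built encoding.cpu / encoding.gpu extension modules, and each call site selects one based on the input tensor's device. A condensed sketch of that dispatch pattern (illustrative only; some_op is a placeholder name, not an operator from the library):

    import torch

    from encoding import cpu                  # pre-built CppExtension ("encoding.cpu")
    if torch.cuda.device_count() > 0:
        from encoding import gpu              # pre-built CUDAExtension ("encoding.gpu")

    def dispatch_some_op(x, *args):
        # Route to the CUDA kernel when the tensor lives on the GPU,
        # otherwise fall back to the CPU implementation.
        if x.is_cuda:
            return gpu.some_op(x, *args)      # placeholder operator name
        return cpu.some_op(x, *args)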
encoding/functions/dist_syncbn.py

@@ -8,7 +8,10 @@
 import torch
 from torch.autograd.function import Function
-from .. import lib
+from encoding import cpu
+if torch.cuda.device_count() > 0:
+    from encoding import gpu

 __all__ = ['dist_syncbatchnorm']

@@ -25,9 +28,9 @@ class dist_syncbatchnorm_(Function):
         _ex, _var = running_mean.contiguous(), running_var.contiguous()
         _exs = _var + _ex ** 2
         if x.is_cuda:
-            y = lib.gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+            y = gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
         else:
-            y = lib.cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+            y = cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
         ctx.save_for_backward(x, _ex, _exs, gamma, beta)
         return y

@@ -36,7 +39,7 @@ class dist_syncbatchnorm_(Function):
             raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
         if x.is_cuda:
-            _ex, _exs = lib.gpu.expectation_forward(x)
+            _ex, _exs = gpu.expectation_forward(x)
         else:
             raise NotImplemented

@@ -62,9 +65,9 @@ class dist_syncbatchnorm_(Function):
         # BN forward + activation
         if x.is_cuda:
-            y = lib.gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+            y = gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
         else:
-            y = lib.cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+            y = cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
         ctx.save_for_backward(x, _ex, _exs, gamma, beta)
         return y

@@ -77,7 +80,7 @@ class dist_syncbatchnorm_(Function):
         # BN backward
         if dz.is_cuda:
             dx, _dex, _dexs, dgamma, dbeta = \
-                lib.gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
+                gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
         else:
             raise NotImplemented

@@ -96,7 +99,7 @@ class dist_syncbatchnorm_(Function):
         _dexs = _dexs / count
         if x.is_cuda:
-            dx_ = lib.gpu.expectation_backward(x, _dex, _dexs)
+            dx_ = gpu.expectation_backward(x, _dex, _dexs)
         else:
             raise NotImplemented
         dx = dx + dx_
encoding/functions/encoding.py

@@ -11,7 +11,10 @@
 import torch
 from torch.autograd import Function, Variable
 import torch.nn.functional as F
-from .. import lib
+from encoding import cpu
+if torch.cuda.device_count() > 0:
+    from encoding import gpu

 __all__ = ['aggregate', 'scaled_l2', 'pairwise_cosine']

@@ -21,18 +24,18 @@ class _aggregate(Function):
         # A \in(BxNxK) R \in(BxNxKxD) => E \in(BxNxD)
         ctx.save_for_backward(A, X, C)
         if A.is_cuda:
-            E = lib.gpu.aggregate_forward(A, X, C)
+            E = gpu.aggregate_forward(A, X, C)
         else:
-            E = lib.cpu.aggregate_forward(A, X, C)
+            E = cpu.aggregate_forward(A, X, C)
         return E

     @staticmethod
     def backward(ctx, gradE):
         A, X, C = ctx.saved_variables
         if A.is_cuda:
-            gradA, gradX, gradC = lib.gpu.aggregate_backward(gradE, A, X, C)
+            gradA, gradX, gradC = gpu.aggregate_backward(gradE, A, X, C)
         else:
-            gradA, gradX, gradC = lib.cpu.aggregate_backward(gradE, A, X, C)
+            gradA, gradX, gradC = cpu.aggregate_backward(gradE, A, X, C)
         return gradA, gradX, gradC

 def aggregate(A, X, C):

@@ -64,9 +67,9 @@ class _scaled_l2(Function):
     @staticmethod
     def forward(ctx, X, C, S):
         if X.is_cuda:
-            SL = lib.gpu.scaled_l2_forward(X, C, S)
+            SL = gpu.scaled_l2_forward(X, C, S)
         else:
-            SL = lib.cpu.scaled_l2_forward(X, C, S)
+            SL = cpu.scaled_l2_forward(X, C, S)
         ctx.save_for_backward(X, C, S, SL)
         return SL

@@ -74,9 +77,9 @@ class _scaled_l2(Function):
     def backward(ctx, gradSL):
         X, C, S, SL = ctx.saved_variables
         if X.is_cuda:
-            gradX, gradC, gradS = lib.gpu.scaled_l2_backward(gradSL, X, C, S, SL)
+            gradX, gradC, gradS = gpu.scaled_l2_backward(gradSL, X, C, S, SL)
         else:
-            gradX, gradC, gradS = lib.cpu.scaled_l2_backward(gradSL, X, C, S, SL)
+            gradX, gradC, gradS = cpu.scaled_l2_backward(gradSL, X, C, S, SL)
         return gradX, gradC, gradS

 def scaled_l2(X, C, S):
encoding/functions/rectify.py

@@ -10,7 +10,9 @@
 import torch
 from torch.autograd import Function
-from .. import lib
+from encoding import cpu
+if torch.cuda.device_count() > 0:
+    from encoding import gpu

 __all__ = ['rectify']

@@ -26,9 +28,9 @@ class _rectify(Function):
         ctx.dilation = dilation
         ctx.average = average
         if x.is_cuda:
-            lib.gpu.conv_rectify(y, x, kernel_size, stride, padding, dilation, average)
+            gpu.conv_rectify(y, x, kernel_size, stride, padding, dilation, average)
         else:
-            lib.cpu.conv_rectify(y, x, kernel_size, stride, padding, dilation, average)
+            cpu.conv_rectify(y, x, kernel_size, stride, padding, dilation, average)
         ctx.mark_dirty(y)
         return y

@@ -36,10 +38,10 @@ class _rectify(Function):
     def backward(ctx, grad_y):
         x, = ctx.saved_variables
         if x.is_cuda:
-            lib.gpu.conv_rectify(grad_y, x, ctx.kernel_size, ctx.stride,
+            gpu.conv_rectify(grad_y, x, ctx.kernel_size, ctx.stride,
                                  ctx.padding, ctx.dilation, ctx.average)
         else:
-            lib.cpu.conv_rectify(grad_y, x, ctx.kernel_size, ctx.stride,
+            cpu.conv_rectify(grad_y, x, ctx.kernel_size, ctx.stride,
                                  ctx.padding, ctx.dilation, ctx.average)
         ctx.mark_dirty(grad_y)
         return grad_y, None, None, None, None, None, None
encoding/functions/syncbn.py

@@ -12,7 +12,10 @@ import torch
 import torch.cuda.comm as comm
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
-from .. import lib
+from encoding import cpu
+if torch.cuda.device_count() > 0:
+    from encoding import gpu

 __all__ = ['moments', 'syncbatchnorm', 'inp_syncbatchnorm']

@@ -20,7 +23,7 @@ class moments_(Function):
     @staticmethod
     def forward(ctx, x):
         if x.is_cuda:
-            ex, ex2 = lib.gpu.expectation_forward(x)
+            ex, ex2 = gpu.expectation_forward(x)
         else:
             raise NotImplemented
         ctx.save_for_backward(x)

@@ -30,7 +33,7 @@ class moments_(Function):
     def backward(ctx, dex, dex2):
         x, = ctx.saved_tensors
         if dex.is_cuda:
-            dx = lib.gpu.expectation_backward(x, dex, dex2)
+            dx = gpu.expectation_backward(x, dex, dex2)
         else:
             raise NotImplemented
         return dx

@@ -57,7 +60,7 @@ class syncbatchnorm_(Function):
         if ctx.training:
             if x.is_cuda:
-                _ex, _exs = lib.gpu.expectation_forward(x)
+                _ex, _exs = gpu.expectation_forward(x)
             else:
                 raise NotImplemented

@@ -94,9 +97,9 @@ class syncbatchnorm_(Function):
         # BN forward + activation
         if x.is_cuda:
-            y = lib.gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+            y = gpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
         else:
-            y = lib.cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+            y = cpu.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

         # Output
         ctx.save_for_backward(x, _ex, _exs, gamma, beta)

@@ -111,7 +114,7 @@ class syncbatchnorm_(Function):
         # BN backward
         if dz.is_cuda:
             dx, _dex, _dexs, dgamma, dbeta = \
-                lib.gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
+                gpu.batchnorm_backward(dz, x, _ex, _exs, gamma, beta, ctx.eps)
         else:
             raise NotImplemented

@@ -137,7 +140,7 @@ class syncbatchnorm_(Function):
         ctx.worker_queue.task_done()
         if x.is_cuda:
-            dx_ = lib.gpu.expectation_backward(x, _dex, _dexs)
+            dx_ = gpu.expectation_backward(x, _dex, _dexs)
         else:
             raise NotImplemented
         dx = dx + dx_

@@ -158,7 +161,7 @@ class syncbatchnorm_(Function):
 def _act_forward(ctx, x):
     if ctx.activation.lower() == "leaky_relu":
         if x.is_cuda:
-            lib.gpu.leaky_relu_forward(x, ctx.slope)
+            gpu.leaky_relu_forward(x, ctx.slope)
         else:
             raise NotImplemented
     else:

@@ -167,7 +170,7 @@ def _act_forward(ctx, x):
 def _act_backward(ctx, x, dx):
     if ctx.activation.lower() == "leaky_relu":
         if x.is_cuda:
-            lib.gpu.leaky_relu_backward(x, dx, ctx.slope)
+            gpu.leaky_relu_backward(x, dx, ctx.slope)
         else:
             raise NotImplemented
     else:

@@ -194,7 +197,7 @@ class inp_syncbatchnorm_(Function):
         if ctx.training:
             if x.is_cuda:
-                _ex, _exs = lib.gpu.expectation_forward(x)
+                _ex, _exs = gpu.expectation_forward(x)
             else:
                 raise NotImplemented

@@ -232,7 +235,7 @@ class inp_syncbatchnorm_(Function):
         # BN forward + activation
         if x.is_cuda:
-            lib.gpu.batchnorm_inp_forward(x, _ex, _exs, gamma, beta, ctx.eps)
+            gpu.batchnorm_inp_forward(x, _ex, _exs, gamma, beta, ctx.eps)
         else:
             raise NotImplemented

@@ -254,7 +257,7 @@ class inp_syncbatchnorm_(Function):
         # BN backward
         if dz.is_cuda:
             dx, _dex, _dexs, dgamma, dbeta = \
-                lib.gpu.batchnorm_inp_backward(dz, z, _ex, _exs, gamma, beta, ctx.eps)
+                gpu.batchnorm_inp_backward(dz, z, _ex, _exs, gamma, beta, ctx.eps)
         else:
             raise NotImplemented

@@ -280,7 +283,7 @@ class inp_syncbatchnorm_(Function):
         ctx.worker_queue.task_done()
         if z.is_cuda:
-            lib.gpu.expectation_inp_backward(dx, z, _dex, _dexs, _ex, _exs, gamma, beta, ctx.eps)
+            gpu.expectation_inp_backward(dx, z, _dex, _dexs, _ex, _exs, gamma, beta, ctx.eps)
         else:
             raise NotImplemented
encoding/lib/__init__.py (deleted, 100644 → 0)

import os
import torch
from torch.utils.cpp_extension import load

cwd = os.path.dirname(os.path.realpath(__file__))
cpu_path = os.path.join(cwd, 'cpu')
gpu_path = os.path.join(cwd, 'gpu')

cpu = load('enclib_cpu', [
    os.path.join(cpu_path, 'operator.cpp'),
    os.path.join(cpu_path, 'encoding_cpu.cpp'),
    os.path.join(cpu_path, 'syncbn_cpu.cpp'),
    os.path.join(cpu_path, 'roi_align_cpu.cpp'),
    os.path.join(cpu_path, 'nms_cpu.cpp'),
    os.path.join(cpu_path, 'rectify_cpu.cpp'),
], build_directory=cpu_path, verbose=False)

if torch.cuda.is_available():
    gpu = load('enclib_gpu', [
        os.path.join(gpu_path, 'operator.cpp'),
        os.path.join(gpu_path, 'activation_kernel.cu'),
        os.path.join(gpu_path, 'encoding_kernel.cu'),
        os.path.join(gpu_path, 'syncbn_kernel.cu'),
        os.path.join(gpu_path, 'roi_align_kernel.cu'),
        os.path.join(gpu_path, 'nms_kernel.cu'),
        os.path.join(gpu_path, 'rectify_cuda.cu'),
        os.path.join(gpu_path, 'lib_ssd.cu'),
    ], extra_cuda_cflags=["--expt-extended-lambda"],
       build_directory=gpu_path, verbose=False)
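With this file gone, nothing is JIT-compiled at import time any more; the cpu and gpu modules must exist as pre-built extensions produced by the new top-level setup.py shown at the end of this diff. A guarded import in the spirit of the new code makes that failure mode explicit (a hedged sketch, not code from the commit):

    import torch

    try:
        from encoding import cpu   # built by the CppExtension in setup.py
    except ImportError as err:
        raise ImportError(
            "torch-encoding C++ extensions are missing; reinstall the package "
            "(e.g. `python setup.py install`) so they get compiled"
        ) from err

    gpu = None
    if torch.cuda.device_count() > 0:
        from encoding import gpu   # built by the CUDAExtension when CUDA_HOME is set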
encoding/lib/cpu/rectify_cpu.cpp

@@ -194,7 +194,7 @@ void conv_rectify_cpu_tempalte(
   at::Tensor input = input_.contiguous();
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "conv_rectify_cuda_frame", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "conv_rectify_cuda_frame", ([&] {
     scalar_t *output_data = output.data_ptr<scalar_t>();
     conv_rectify_cpu_frame<scalar_t>(output_data,
encoding/lib/cpu/roi_align_cpu.cpp

@@ -410,7 +410,7 @@ at::Tensor ROIAlign_Forward_CPU(
   AT_ASSERT(input.is_contiguous());
   AT_ASSERT(bottom_rois.is_contiguous());
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_Forward_CPU", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_Forward_CPU", ([&] {
     ROIAlignForwardCompute<scalar_t>(
       output.numel(),
       input.data<scalar_t>(),
encoding/lib/cpu/setup.py (deleted, 100644 → 0)

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='enclib_cpu',
    ext_modules=[
        CppExtension('enclib_cpu', [
            'operator.cpp',
            'roi_align_cpu.cpp',
            'encoding_cpu.cpp',
            'syncbn_cpu.cpp',
            'nms_cpu.cpp',
        ]),
    ],
    cmdclass={'build_ext': BuildExtension})
encoding/lib/gpu/activation_kernel.cu

@@ -10,19 +10,19 @@
 namespace {

-template <typename T>
-inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
-  // Create thrust pointers
-  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
-  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
-
-  thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
-                       [slope] __device__ (const T& dz) { return dz * slope; },
-                       [] __device__ (const T& z) { return z < 0; });
-  thrust::transform_if(th_z, th_z + count, th_z,
-                       [slope] __device__ (const T& z) { return z / slope; },
-                       [] __device__ (const T& z) { return z < 0; });
-}
+// template<typename T>
+// inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
+//   // Create thrust pointers
+//   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
+//   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
+//
+//   thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
+//                        [slope] __device__ (const T& dz) { return dz * slope; },
+//                        [] __device__ (const T& z) { return z < 0; });
+//   thrust::transform_if(th_z, th_z + count, th_z,
+//                        [slope] __device__ (const T& z) { return z / slope; },
+//                        [] __device__ (const T& z) { return z < 0; });
+// }
 }

@@ -33,12 +33,12 @@ void LeakyRelu_Forward_CUDA(at::Tensor z, float slope) {
 void LeakyRelu_Backward_CUDA(at::Tensor z, at::Tensor dz, float slope) {
   int64_t count = z.numel();
+  /*
   AT_DISPATCH_FLOATING_TYPES(z.type(), "LeakyRelu_Backward_CUDA", ([&] {
     leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
   }));
-  /*
+  */
   // unstable after scaling
   at::leaky_relu_(z, 1.0 / slope);
   at::leaky_relu_backward(dz, z, slope);
-  */
 }
encoding/lib/gpu/encoding_kernel.cu

@@ -172,7 +172,7 @@ at::Tensor Aggregate_Forward_CUDA(
   dim3 blocks(C_.size(1), C_.size(0), X_.size(0));
   dim3 threads(getNumThreads(X_.size(1)));
-  AT_DISPATCH_FLOATING_TYPES(A_.type(), "Aggregate_Forward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(A_.scalar_type(), "Aggregate_Forward_CUDA", ([&] {
     DeviceTensor<scalar_t, 3> E = devicetensor<scalar_t, 3>(E_);
     DeviceTensor<scalar_t, 3> A = devicetensor<scalar_t, 3>(A_);
     DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);

@@ -197,7 +197,7 @@ std::vector<at::Tensor> Aggregate_Backward_CUDA(
   // B, K, D
   dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
   dim3 threads(getNumThreads(C_.size(1)));
-  AT_DISPATCH_FLOATING_TYPES(A_.type(), "Aggregate_Backward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(A_.scalar_type(), "Aggregate_Backward_CUDA", ([&] {
     /* Device tensors */
     DeviceTensor<scalar_t, 3> GA = devicetensor<scalar_t, 3>(gradA_);
     DeviceTensor<scalar_t, 3> GE = devicetensor<scalar_t, 3>(GE_);

@@ -220,7 +220,7 @@ at::Tensor ScaledL2_Forward_CUDA(
   dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
   dim3 threads(getNumThreads(C_.size(1)));
-  AT_DISPATCH_FLOATING_TYPES(X_.type(), "ScaledL2_Forward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(X_.scalar_type(), "ScaledL2_Forward_CUDA", ([&] {
     /* Device tensors */
     DeviceTensor<scalar_t, 3> SL = devicetensor<scalar_t, 3>(SL_);
     DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);

@@ -249,7 +249,7 @@ std::vector<at::Tensor> ScaledL2_Backward_CUDA(
   dim3 blocks2(C_.size(1), C_.size(0));
   dim3 threads2(getNumThreads(X_.size(1)));
   auto GS_ = (GSL_ * (SL_ / S_.view({1, 1, C_.size(0)}))).sum(0).sum(0);
-  AT_DISPATCH_FLOATING_TYPES(X_.type(), "ScaledL2_Backward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(X_.scalar_type(), "ScaledL2_Backward_CUDA", ([&] {
     /* Device tensors */
     DeviceTensor<scalar_t, 3> GSL = devicetensor<scalar_t, 3>(GSL_);
     DeviceTensor<scalar_t, 3> GX = devicetensor<scalar_t, 3>(GX_);
encoding/lib/gpu/rectify_cuda.cu

@@ -179,7 +179,7 @@ void conv_rectify_cuda_tempalte(
   const uint32_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
   const uint32_t num_blocks = at::cuda::ATenCeilDiv<uint32_t>(count, num_threads);
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "conv_rectify_cuda_frame", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "conv_rectify_cuda_frame", ([&] {
     //using accscalar_t = acc_type<scalar_t, true>;
     scalar_t *output_data = output.data_ptr<scalar_t>();
     conv_rectify_cuda_frame<scalar_t, scalar_t>
encoding/lib/gpu/roi_align_kernel.cu

@@ -372,7 +372,7 @@ at::Tensor ROIAlign_Forward_CUDA(
   auto count = output.numel();
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_Forward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_Forward_CUDA", ([&] {
     RoIAlignForwardKernel<scalar_t>
       <<<ROI_GET_BLOCKS(count), ROI_CUDA_NUM_THREADS,

@@ -419,7 +419,7 @@ at::Tensor ROIAlign_Backward_CUDA(
   auto num_rois = rois.size(0);
   auto count = grad_output.numel();
-  AT_DISPATCH_FLOATING_TYPES(rois.type(), "ROIAlign_Backward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(rois.scalar_type(), "ROIAlign_Backward_CUDA", ([&] {
     RoIAlignBackwardKernel<scalar_t>
       <<<ROI_GET_BLOCKS(count), ROI_CUDA_NUM_THREADS,
encoding/lib/gpu/setup.py (deleted, 100644 → 0)

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='enclib_gpu',
    ext_modules=[
        CUDAExtension('enclib_gpu', [
            'operator.cpp',
            'activation_kernel.cu',
            'encoding_kernel.cu',
            'syncbn_kernel.cu',
            'roi_align_kernel.cu',
            'nms_kernel.cu',
            'rectify.cu',
        ]),
    ],
    cmdclass={'build_ext': BuildExtension})
encoding/lib/gpu/syncbn_kernel.cu

@@ -274,7 +274,7 @@ at::Tensor BatchNorm_Forward_CUDA(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 blocks(input_.size(1));
   dim3 threads(getNumThreads(input_.size(2)));
-  AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Forward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input_.scalar_type(), "BatchNorm_Forward_CUDA", ([&] {
     /* Device tensors */
     DeviceTensor<scalar_t, 3> output = devicetensor<scalar_t, 3>(output_);
     DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);

@@ -301,7 +301,7 @@ at::Tensor BatchNorm_Forward_Inp_CUDA(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 blocks(input_.size(1));
   dim3 threads(getNumThreads(input_.size(2)));
-  AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Forward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input_.scalar_type(), "BatchNorm_Forward_CUDA", ([&] {
     /* Device tensors */
     DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
     DeviceTensor<scalar_t, 1> ex = devicetensor<scalar_t, 1>(ex_);

@@ -336,7 +336,7 @@ std::vector<at::Tensor> BatchNorm_Inp_Backward_CUDA(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 blocks(output_.size(1));
   dim3 threads(getNumThreads(output_.size(2)));
-  AT_DISPATCH_FLOATING_TYPES(output_.type(), "BatchNorm_Inp_Backward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(output_.scalar_type(), "BatchNorm_Inp_Backward_CUDA", ([&] {
     /* Device tensors */
     DeviceTensor<scalar_t, 3> gradoutput = devicetensor<scalar_t, 3>(gradoutput_);
     DeviceTensor<scalar_t, 3> output = devicetensor<scalar_t, 3>(output_);

@@ -379,7 +379,7 @@ std::vector<at::Tensor> BatchNorm_Backward_CUDA(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 blocks(input_.size(1));
   dim3 threads(getNumThreads(input_.size(2)));
-  AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Inp_Backward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input_.scalar_type(), "BatchNorm_Inp_Backward_CUDA", ([&] {
     /* Device tensors */
     DeviceTensor<scalar_t, 3> gradoutput = devicetensor<scalar_t, 3>(gradoutput_);
     DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);

@@ -411,7 +411,7 @@ std::vector<at::Tensor> Expectation_Forward_CUDA(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 blocks(input_.size(1));
   dim3 threads(getNumThreads(input_.size(2)));
-  AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_forward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input_.scalar_type(), "SumSquare_forward_CUDA", ([&] {
     scalar_t norm = scalar_t(1) / (input_.size(0) * input_.size(2));
     /* Device tensors */
     DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);

@@ -435,7 +435,7 @@ at::Tensor Expectation_Backward_CUDA(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 blocks(input_.size(1));
   dim3 threads(getNumThreads(input_.size(2)));
-  AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_Backward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input_.scalar_type(), "SumSquare_Backward_CUDA", ([&] {
     scalar_t norm = scalar_t(1) / (input_.size(0) * input_.size(2));
     /* Device tensors */
     DeviceTensor<scalar_t, 3> gradInput = devicetensor<scalar_t, 3>(gradInput_);

@@ -467,7 +467,7 @@ at::Tensor Expectation_Inp_Backward_CUDA(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   dim3 blocks(output_.size(1));
   dim3 threads(getNumThreads(output_.size(2)));
-  AT_DISPATCH_FLOATING_TYPES(output_.type(), "SumSquare_Backward_CUDA", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(output_.scalar_type(), "SumSquare_Backward_CUDA", ([&] {
     scalar_t norm = scalar_t(1) / (output_.size(0) * output_.size(2));
     /* Device tensors */
     DeviceTensor<scalar_t, 3> gradInput = devicetensor<scalar_t, 3>(gradInput_);
experiments/recognition/README.md

-- [Link to the EncNet CIFAR experiments and pre-trained models](http://hangzh.com/PyTorch-Encoding/experiments/cifar.html)
-- [Link to the Deep TEN experiments and pre-trained models](http://hangzh.com/PyTorch-Encoding/experiments/texture.html)
+- [Link to Docs](https://hangzhang.org/PyTorch-Encoding/model_zoo/imagenet.html)
setup.py

@@ -10,10 +10,14 @@
 import io
 import os
+import glob
 import subprocess

 from setuptools import setup, find_packages
+import torch
+from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

 cwd = os.path.dirname(os.path.abspath(__file__))
 version = '1.2.2'

@@ -46,6 +50,55 @@ requirements = [
     'requests',
 ]

+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    cpu_extensions_dir = os.path.join(this_dir, "encoding", "lib", "cpu")
+    gpu_extensions_dir = os.path.join(this_dir, "encoding", "lib", "gpu")
+
+    source_cpu = glob.glob(os.path.join(cpu_extensions_dir, "*.cpp"))
+    source_cuda = glob.glob(os.path.join(gpu_extensions_dir, "*.cpp")) + \
+        glob.glob(os.path.join(gpu_extensions_dir, "*.cu"))
+    print('c++: ', source_cpu)
+    print('cuda: ', source_cuda)
+
+    sources = source_cpu
+    extra_compile_args = {"cxx": []}
+    include_dirs = [cpu_extensions_dir]
+
+    ext_modules = [
+        CppExtension(
+            "encoding.cpu",
+            source_cpu,
+            include_dirs=include_dirs,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+
+    if CUDA_HOME is not None:
+        define_macros = [("WITH_CUDA", None)]
+        include_dirs += [gpu_extensions_dir]
+        extra_compile_args["nvcc"] = [
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ]
+        ext_modules.extend([
+            CUDAExtension(
+                "encoding.gpu",
+                source_cuda,
+                include_dirs=include_dirs,
+                define_macros=define_macros,
+                extra_compile_args=extra_compile_args,
+            )
+        ])
+
+    return ext_modules
+
 if __name__ == '__main__':
     create_version_file()
     setup(

@@ -68,4 +121,6 @@ if __name__ == '__main__':
             'lib/gpu/*.cpp',
             'lib/gpu/*.cu',
         ]},
+        ext_modules=get_extensions(),
+        cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
     )
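Together with the deleted encoding/lib/__init__.py and the per-directory setup.py files, this is the "pre-compile" part of the commit: the C++/CUDA sources are compiled once at install time by setuptools instead of being JIT-compiled at first import. A stripped-down sketch of the same wiring, using only torch.utils.cpp_extension APIs (my_pkg and the directory layout are placeholders, not this project's names):

    import glob
    import os

    from setuptools import setup
    from torch.utils.cpp_extension import (
        BuildExtension, CppExtension, CUDAExtension, CUDA_HOME)

    here = os.path.dirname(os.path.abspath(__file__))
    cpu_sources = glob.glob(os.path.join(here, "my_pkg", "lib", "cpu", "*.cpp"))
    gpu_sources = glob.glob(os.path.join(here, "my_pkg", "lib", "gpu", "*.cpp")) + \
        glob.glob(os.path.join(here, "my_pkg", "lib", "gpu", "*.cu"))

    # Always build the CPU extension; add the CUDA one only when a toolkit is present.
    ext_modules = [CppExtension("my_pkg.cpu", cpu_sources)]
    if CUDA_HOME is not None:
        ext_modules.append(CUDAExtension("my_pkg.gpu", gpu_sources,
                                         define_macros=[("WITH_CUDA", None)]))

    setup(
        name="my_pkg",
        ext_modules=ext_modules,
        cmdclass={"build_ext": BuildExtension},  # torch's build_ext drives the compiler
    )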