ModelZoo / MetaPortrait_pytorch / Commits

Commit 5efcc6ff
authored Oct 11, 2023 by mashun1

    metaportrait

Pipeline #584: canceled with stages
Changes: 258, Pipelines: 1
Showing 20 changed files with 4147 additions and 0 deletions (+4147, -0)
sr_model/Basicsr/basicsr/ops/dcn/__init__.py  +7  -0
sr_model/Basicsr/basicsr/ops/dcn/deform_conv.py  +379  -0
sr_model/Basicsr/basicsr/ops/dcn/src/deform_conv_cuda.cpp  +685  -0
sr_model/Basicsr/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu  +867  -0
sr_model/Basicsr/basicsr/ops/dcn/src/deform_conv_ext.cpp  +164  -0
sr_model/Basicsr/basicsr/ops/fused_act/__init__.py  +3  -0
sr_model/Basicsr/basicsr/ops/fused_act/fused_act.py  +95  -0
sr_model/Basicsr/basicsr/ops/fused_act/src/fused_bias_act.cpp  +26  -0
sr_model/Basicsr/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu  +100  -0
sr_model/Basicsr/basicsr/ops/upfirdn2d/__init__.py  +3  -0
sr_model/Basicsr/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp  +24  -0
sr_model/Basicsr/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu  +370  -0
sr_model/Basicsr/basicsr/ops/upfirdn2d/upfirdn2d.py  +192  -0
sr_model/Basicsr/basicsr/test.py  +45  -0
sr_model/Basicsr/basicsr/train.py  +237  -0
sr_model/Basicsr/basicsr/utils/__init__.py  +46  -0
sr_model/Basicsr/basicsr/utils/color_util.py  +208  -0
sr_model/Basicsr/basicsr/utils/diffjpeg.py  +515  -0
sr_model/Basicsr/basicsr/utils/dist_util.py  +82  -0
sr_model/Basicsr/basicsr/utils/download_util.py  +99  -0
Too many changes to show: to preserve performance, only 258 of 258+ files are displayed.
sr_model/Basicsr/basicsr/ops/dcn/__init__.py (new file, 0 → 100644)
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
                          modulated_deform_conv)

__all__ = [
    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
    'modulated_deform_conv'
]
sr_model/Basicsr/basicsr/ops/dcn/deform_conv.py (new file, 0 → 100644)
import math
import os
import torch
from torch import nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn import functional as F
from torch.nn.modules.utils import _pair, _single

BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
    from torch.utils.cpp_extension import load
    module_path = os.path.dirname(__file__)
    deform_conv_ext = load(
        'deform_conv',
        sources=[
            os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
            os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
            os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
        ],
    )
else:
    try:
        from . import deform_conv_ext
    except ImportError:
        pass
        # avoid annoying print output
        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
        #       '1. compile with BASICSR_EXT=True. or\n '
        #       '2. set BASICSR_JIT=True during running')


class DeformConvFunction(Function):

    @staticmethod
    def forward(ctx,
                input,
                offset,
                weight,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1,
                im2col_step=64):
        if input is not None and input.dim() != 4:
            raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
            deform_conv_ext.deform_conv_forward(input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
                                                weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
                                                ctx.padding[1], ctx.padding[0], ctx.dilation[1], ctx.dilation[0],
                                                ctx.groups, ctx.deformable_groups, cur_im2col_step)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input, grad_offset, weight,
                                                           ctx.bufs_[0], weight.size(3), weight.size(2), ctx.stride[1],
                                                           ctx.stride[0], ctx.padding[1], ctx.padding[0],
                                                           ctx.dilation[1], ctx.dilation[0], ctx.groups,
                                                           ctx.deformable_groups, cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight, ctx.bufs_[0],
                                                                ctx.bufs_[1], weight.size(3), weight.size(2),
                                                                ctx.stride[1], ctx.stride[0], ctx.padding[1],
                                                                ctx.padding[0], ctx.dilation[1], ctx.dilation[0],
                                                                ctx.groups, ctx.deformable_groups, 1, cur_im2col_step)

        return (grad_input, grad_offset, grad_weight, None, None, None, None, None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
        return output_size


class ModulatedDeformConvFunction(Function):

    @staticmethod
    def forward(ctx,
                input,
                offset,
                mask,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1):
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
            ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
                                                      ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
                                                      ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
                                                      ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
                                                       grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
                                                       grad_output, weight.shape[2], weight.shape[3], ctx.stride,
                                                       ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
                                                       ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
        width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
        return n, channels_out, height_out, width_out


deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply


class DeformConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        super(DeformConv, self).__init__()

        assert not bias
        assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}'
        assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}'

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x, offset):
        # To fix an assert error in deform_conv_cuda.cpp:128
        # input image is smaller than kernel
        input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
        if input_pad:
            pad_h = max(self.kernel_size[0] - x.size(2), 0)
            pad_w = max(self.kernel_size[1] - x.size(3), 0)
            x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
            offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
        out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
                          self.deformable_groups)
        if input_pad:
            out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
        return out


class DeformConvPack(DeformConv):
    """A Deformable Conv Encapsulation that acts as normal Conv layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
                           self.deformable_groups)


class ModulatedDeformConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.init_weights()

    def init_weights(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                     self.dilation, self.groups, self.deformable_groups)


class ModulatedDeformConvPack(ModulatedDeformConv):
    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_weights()

    def init_weights(self):
        super(ModulatedDeformConvPack, self).init_weights()
        if hasattr(self, 'conv_offset'):
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

    def forward(self, x):
        out = self.conv_offset(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                     self.dilation, self.groups, self.deformable_groups)
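
The commit itself ships no usage example for this module. As a rough sketch (assuming the repo's Basicsr package is on PYTHONPATH, the deform_conv_ext extension has been compiled, and a CUDA device is available), the two "Pack" variants predict their own offsets (and, for the modulated version, masks) from the input, so they can be dropped in where an nn.Conv2d would go:

import torch

from basicsr.ops.dcn import DeformConvPack, ModulatedDeformConvPack

x = torch.randn(2, 64, 32, 32, device='cuda')

# DCN v1: conv_offset(x) produces 2*k_h*k_w offset channels per deformable group.
dcn_v1 = DeformConvPack(64, 64, kernel_size=3, padding=1).cuda()
# DCN v2: conv_offset(x) is chunked into (o1, o2, mask); the mask goes through a sigmoid.
dcn_v2 = ModulatedDeformConvPack(64, 64, kernel_size=3, padding=1).cuda()

y1 = dcn_v1(x)
y2 = dcn_v2(x)
print(y1.shape, y2.shape)  # both keep the 32x32 spatial size with kernel 3, stride 1, padding 1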
sr_model/Basicsr/basicsr/ops/dcn/src/deform_conv_cuda.cpp (new file, 0 → 100644)
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c

#include <torch/extension.h>
#include <ATen/DeviceGuard.h>

#include <cmath>
#include <vector>

void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
                       const int channels, const int height, const int width,
                       const int ksize_h, const int ksize_w, const int pad_h,
                       const int pad_w, const int stride_h, const int stride_w,
                       const int dilation_h, const int dilation_w,
                       const int parallel_imgs, const int deformable_group,
                       at::Tensor data_col);

void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
                       const int channels, const int height, const int width,
                       const int ksize_h, const int ksize_w, const int pad_h,
                       const int pad_w, const int stride_h, const int stride_w,
                       const int dilation_h, const int dilation_w,
                       const int parallel_imgs, const int deformable_group,
                       at::Tensor grad_im);

void deformable_col2im_coord(
    const at::Tensor data_col, const at::Tensor data_im,
    const at::Tensor data_offset, const int channels, const int height,
    const int width, const int ksize_h, const int ksize_w, const int pad_h,
    const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int parallel_imgs,
    const int deformable_group, at::Tensor grad_offset);

void modulated_deformable_im2col_cuda(
    const at::Tensor data_im, const at::Tensor data_offset,
    const at::Tensor data_mask, const int batch_size, const int channels,
    const int height_im, const int width_im, const int height_col,
    const int width_col, const int kernel_h, const int kenerl_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group,
    at::Tensor data_col);

void modulated_deformable_col2im_cuda(
    const at::Tensor data_col, const at::Tensor data_offset,
    const at::Tensor data_mask, const int batch_size, const int channels,
    const int height_im, const int width_im, const int height_col,
    const int width_col, const int kernel_h, const int kenerl_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group,
    at::Tensor grad_im);

void modulated_deformable_col2im_coord_cuda(
    const at::Tensor data_col, const at::Tensor data_im,
    const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int deformable_group, at::Tensor grad_offset,
    at::Tensor grad_mask);

void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
                 int padW, int dilationH, int dilationW, int group,
                 int deformable_group) {
  TORCH_CHECK(weight.ndimension() == 4,
              "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
              "but got: %s",
              weight.ndimension());

  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  TORCH_CHECK(kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got kH: %d kW: %d",
              kH, kW);

  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
              "kernel size should be consistent with weight, ",
              "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
              kH, kW, weight.size(2), weight.size(3));

  TORCH_CHECK(dW > 0 && dH > 0,
              "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);

  TORCH_CHECK(
      dilationW > 0 && dilationH > 0,
      "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
      dilationH, dilationW);

  int ndim = input.ndimension();
  int dimf = 0;
  int dimh = 1;
  int dimw = 2;

  if (ndim == 4) {
    dimf++;
    dimh++;
    dimw++;
  }

  TORCH_CHECK(ndim == 3 || ndim == 4,
              "3D or 4D input tensor expected but got: %s", ndim);

  long nInputPlane = weight.size(1) * group;
  long inputHeight = input.size(dimh);
  long inputWidth = input.size(dimw);
  long nOutputPlane = weight.size(0);
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;

  TORCH_CHECK(nInputPlane % deformable_group == 0,
              "input channels must divide deformable group size");

  if (outputWidth < 1 || outputHeight < 1)
    AT_ERROR(
        "Given input size: (%ld x %ld x %ld). "
        "Calculated output size: (%ld x %ld x %ld). Output size is too small",
        nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
        outputWidth);

  TORCH_CHECK(input.size(1) == nInputPlane,
              "invalid number of input planes, expected: %d, but got: %d",
              nInputPlane, input.size(1));

  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
              "input image is smaller than kernel");

  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
              "invalid spatial size of offset, expected height: %d width: %d, but "
              "got height: %d width: %d",
              outputHeight, outputWidth, offset.size(2), offset.size(3));

  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
              "invalid number of channels of offset");

  if (gradOutput != NULL) {
    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
                "invalid number of gradOutput planes, expected: %d, but got: %d",
                nOutputPlane, gradOutput->size(dimf));

    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
                 gradOutput->size(dimw) == outputWidth),
                "invalid size of gradOutput, expected height: %d width: %d , but "
                "got height: %d width: %d",
                outputHeight, outputWidth, gradOutput->size(dimh),
                gradOutput->size(dimw));
  }
}

int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
                             at::Tensor offset, at::Tensor output,
                             at::Tensor columns, at::Tensor ones, int kW,
                             int kH, int dW, int dH, int padW, int padH,
                             int dilationW, int dilationH, int group,
                             int deformable_group, int im2col_step) {
  // todo: resize columns to include im2col: done
  // todo: add im2col_step as input
  // todo: add new output buffer and transpose it to output (or directly
  // transpose output) todo: possibly change data indexing because of
  // parallel_imgs

  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  weight = weight.contiguous();

  int batch = 1;
  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input.unsqueeze_(0);
    offset.unsqueeze_(0);
  }

  // todo: assert batchsize dividable by im2col_step

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                        outputHeight, outputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
    ones = at::ones({outputHeight, outputWidth}, input.options());
  }

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  at::Tensor output_buffer =
      at::zeros({batchSize / im2col_step, nOutputPlane,
                 im2col_step * outputHeight, outputWidth},
                output.options());

  output_buffer = output_buffer.view(
      {output_buffer.size(0), group, output_buffer.size(1) / group,
       output_buffer.size(2), output_buffer.size(3)});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      output_buffer[elt][g] = output_buffer[elt][g]
                                  .flatten(1)
                                  .addmm_(weight[g].flatten(1), columns[g])
                                  .view_as(output_buffer[elt][g]);
    }
  }

  output_buffer = output_buffer.view(
      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
       output_buffer.size(3), output_buffer.size(4)});

  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
                                      im2col_step, outputHeight, outputWidth});
  output_buffer.transpose_(1, 2);
  output.copy_(output_buffer);
  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    output = output.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
  }

  return 1;
}

int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
                                    at::Tensor gradOutput, at::Tensor gradInput,
                                    at::Tensor gradOffset, at::Tensor weight,
                                    at::Tensor columns, int kW, int kH, int dW,
                                    int dH, int padW, int padH, int dilationW,
                                    int dilationH, int group,
                                    int deformable_group, int im2col_step) {
  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();
  weight = weight.contiguous();

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // change order of grad output
  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  gradInput = gradInput.view({batchSize / im2col_step, im2col_step,
                              nInputPlane, inputHeight, inputWidth});
  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
                                deformable_group * 2 * kH * kW, outputHeight,
                                outputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // divide into groups
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    gradOutput = gradOutput.view(
        {gradOutput.size(0), group, gradOutput.size(1) / group,
         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});

    for (int g = 0; g < group; g++) {
      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                                     gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradOutput = gradOutput.view(
        {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});

    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
                            inputHeight, inputWidth, kH, kW, padH, padW, dH,
                            dW, dilationH, dilationW, im2col_step,
                            deformable_group, gradOffset[elt]);

    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group,
                      gradInput[elt]);
  }

  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  gradOffset = gradOffset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
    gradOffset =
        gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
  }

  return 1;
}

int deform_conv_backward_parameters_cuda(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight,  // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step) {
  // todo: transpose and reshape outGrad
  // todo: reshape columns
  // todo: add im2col_step as input

  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
              padW, dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view(
        at::IntList({1, input.size(0), input.size(1), input.size(2)}));
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = gradWeight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
                             im2col_step, outputHeight, outputWidth});
  gradOutputBuffer.copy_(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
                             im2col_step * outputHeight, outputWidth});

  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    // divide into group
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    gradWeight =
        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
                         gradWeight.size(2), gradWeight.size(3)});

    for (int g = 0; g < group; g++) {
      gradWeight[g] = gradWeight[g]
                          .flatten(1)
                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
                                  columns[g].transpose(1, 0), 1.0, scale)
                          .view_as(gradWeight[g]);
    }
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0),
         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
                                  gradWeight.size(2), gradWeight.size(3),
                                  gradWeight.size(4)});
  }

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
  }

  return 1;
}

void modulated_deform_conv_cuda_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias) {
  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
             kernel_h_, kernel_w, kernel_h_, kernel_w_);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
             channels, channels_kernel * group);

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out}).zero_();
  // resize temporary columns
  columns =
      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
                input.options());

  output = output.view({output.size(0), group, output.size(1) / group,
                        output.size(2), output.size(3)});

  for (int b = 0; b < batch; b++) {
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    // divide into group
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }

    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  output = output.view({output.size(0), output.size(1) * output.size(2),
                        output.size(3), output.size(4)});

  if (with_bias) {
    output += bias.view({1, bias.size(0), 1, 1});
  }
}

void modulated_deform_conv_cuda_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group,
    int deformable_group, const bool with_bias) {
  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
             kernel_h_, kernel_w, kernel_h_, kernel_w_);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
             channels, channels_kernel * group);

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
                      input.options());

  grad_output =
      grad_output.view({grad_output.size(0), group,
                        grad_output.size(1) / group, grad_output.size(2),
                        grad_output.size(3)});

  for (int b = 0; b < batch; b++) {
    // divide int group
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(
        columns, input[b], offset[b], mask[b], 1, channels, height, width,
        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
        grad_mask[b]);
    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(
        columns, offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, grad_input[b]);

    // gradient w.r.t. weight, dWeight should accumulate across the batch and
    // group
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
                                    grad_weight.size(1), grad_weight.size(2),
                                    grad_weight.size(3)});
    if (with_bias)
      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});

    for (int g = 0; g < group; g++) {
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
              .view_as(grad_weight[g]);
      if (with_bias) {
        grad_bias[g] =
            grad_bias[g]
                .view({-1, 1})
                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
                .view(-1);
      }
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
                                    grad_weight.size(2), grad_weight.size(3),
                                    grad_weight.size(4)});
    if (with_bias)
      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }
  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
                                  grad_output.size(2), grad_output.size(3),
                                  grad_output.size(4)});
}
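
Every entry point above recomputes the output spatial size with the same dilated-convolution formula, outputHeight = (inputHeight + 2*padH - (dilationH*(kH - 1) + 1)) / dH + 1 (and likewise for the width). A small illustrative Python helper, not part of the commit, that mirrors this arithmetic:

def conv_output_size(in_size, kernel, stride, pad, dilation):
    """Length formula used throughout deform_conv_cuda.cpp:
    (in + 2*pad - (dilation*(k - 1) + 1)) // stride + 1."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + 2 * pad - effective_kernel) // stride + 1

# Example: a 32-pixel side, 3x3 kernel, stride 1, padding 1, dilation 1 keeps the size.
assert conv_output_size(32, kernel=3, stride=1, pad=1, dilation=1) == 32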
sr_model/Basicsr/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu (new file, 0 → 100644)
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer ********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <float.h>
using
namespace
at
;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
const
int
CUDA_NUM_THREADS
=
1024
;
const
int
kMaxGridNum
=
65535
;
inline
int
GET_BLOCKS
(
const
int
N
)
{
return
std
::
min
(
kMaxGridNum
,
(
N
+
CUDA_NUM_THREADS
-
1
)
/
CUDA_NUM_THREADS
);
}
template
<
typename
scalar_t
>
__device__
scalar_t
deformable_im2col_bilinear
(
const
scalar_t
*
bottom_data
,
const
int
data_width
,
const
int
height
,
const
int
width
,
scalar_t
h
,
scalar_t
w
)
{
int
h_low
=
floor
(
h
);
int
w_low
=
floor
(
w
);
int
h_high
=
h_low
+
1
;
int
w_high
=
w_low
+
1
;
scalar_t
lh
=
h
-
h_low
;
scalar_t
lw
=
w
-
w_low
;
scalar_t
hh
=
1
-
lh
,
hw
=
1
-
lw
;
scalar_t
v1
=
0
;
if
(
h_low
>=
0
&&
w_low
>=
0
)
v1
=
bottom_data
[
h_low
*
data_width
+
w_low
];
scalar_t
v2
=
0
;
if
(
h_low
>=
0
&&
w_high
<=
width
-
1
)
v2
=
bottom_data
[
h_low
*
data_width
+
w_high
];
scalar_t
v3
=
0
;
if
(
h_high
<=
height
-
1
&&
w_low
>=
0
)
v3
=
bottom_data
[
h_high
*
data_width
+
w_low
];
scalar_t
v4
=
0
;
if
(
h_high
<=
height
-
1
&&
w_high
<=
width
-
1
)
v4
=
bottom_data
[
h_high
*
data_width
+
w_high
];
scalar_t
w1
=
hh
*
hw
,
w2
=
hh
*
lw
,
w3
=
lh
*
hw
,
w4
=
lh
*
lw
;
scalar_t
val
=
(
w1
*
v1
+
w2
*
v2
+
w3
*
v3
+
w4
*
v4
);
return
val
;
}
template
<
typename
scalar_t
>
__device__
scalar_t
get_gradient_weight
(
scalar_t
argmax_h
,
scalar_t
argmax_w
,
const
int
h
,
const
int
w
,
const
int
height
,
const
int
width
)
{
if
(
argmax_h
<=
-
1
||
argmax_h
>=
height
||
argmax_w
<=
-
1
||
argmax_w
>=
width
)
{
//empty
return
0
;
}
int
argmax_h_low
=
floor
(
argmax_h
);
int
argmax_w_low
=
floor
(
argmax_w
);
int
argmax_h_high
=
argmax_h_low
+
1
;
int
argmax_w_high
=
argmax_w_low
+
1
;
scalar_t
weight
=
0
;
if
(
h
==
argmax_h_low
&&
w
==
argmax_w_low
)
weight
=
(
h
+
1
-
argmax_h
)
*
(
w
+
1
-
argmax_w
);
if
(
h
==
argmax_h_low
&&
w
==
argmax_w_high
)
weight
=
(
h
+
1
-
argmax_h
)
*
(
argmax_w
+
1
-
w
);
if
(
h
==
argmax_h_high
&&
w
==
argmax_w_low
)
weight
=
(
argmax_h
+
1
-
h
)
*
(
w
+
1
-
argmax_w
);
if
(
h
==
argmax_h_high
&&
w
==
argmax_w_high
)
weight
=
(
argmax_h
+
1
-
h
)
*
(
argmax_w
+
1
-
w
);
return
weight
;
}
template
<
typename
scalar_t
>
__device__
scalar_t
get_coordinate_weight
(
scalar_t
argmax_h
,
scalar_t
argmax_w
,
const
int
height
,
const
int
width
,
const
scalar_t
*
im_data
,
const
int
data_width
,
const
int
bp_dir
)
{
if
(
argmax_h
<=
-
1
||
argmax_h
>=
height
||
argmax_w
<=
-
1
||
argmax_w
>=
width
)
{
//empty
return
0
;
}
int
argmax_h_low
=
floor
(
argmax_h
);
int
argmax_w_low
=
floor
(
argmax_w
);
int
argmax_h_high
=
argmax_h_low
+
1
;
int
argmax_w_high
=
argmax_w_low
+
1
;
scalar_t
weight
=
0
;
if
(
bp_dir
==
0
)
{
if
(
argmax_h_low
>=
0
&&
argmax_w_low
>=
0
)
weight
+=
-
1
*
(
argmax_w_low
+
1
-
argmax_w
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_low
];
if
(
argmax_h_low
>=
0
&&
argmax_w_high
<=
width
-
1
)
weight
+=
-
1
*
(
argmax_w
-
argmax_w_low
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_high
];
if
(
argmax_h_high
<=
height
-
1
&&
argmax_w_low
>=
0
)
weight
+=
(
argmax_w_low
+
1
-
argmax_w
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_low
];
if
(
argmax_h_high
<=
height
-
1
&&
argmax_w_high
<=
width
-
1
)
weight
+=
(
argmax_w
-
argmax_w_low
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_high
];
}
else
if
(
bp_dir
==
1
)
{
if
(
argmax_h_low
>=
0
&&
argmax_w_low
>=
0
)
weight
+=
-
1
*
(
argmax_h_low
+
1
-
argmax_h
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_low
];
if
(
argmax_h_low
>=
0
&&
argmax_w_high
<=
width
-
1
)
weight
+=
(
argmax_h_low
+
1
-
argmax_h
)
*
im_data
[
argmax_h_low
*
data_width
+
argmax_w_high
];
if
(
argmax_h_high
<=
height
-
1
&&
argmax_w_low
>=
0
)
weight
+=
-
1
*
(
argmax_h
-
argmax_h_low
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_low
];
if
(
argmax_h_high
<=
height
-
1
&&
argmax_w_high
<=
width
-
1
)
weight
+=
(
argmax_h
-
argmax_h_low
)
*
im_data
[
argmax_h_high
*
data_width
+
argmax_w_high
];
}
return
weight
;
}
template
<
typename
scalar_t
>
__global__
void
deformable_im2col_gpu_kernel
(
const
int
n
,
const
scalar_t
*
data_im
,
const
scalar_t
*
data_offset
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
channel_per_deformable_group
,
const
int
batch_size
,
const
int
num_channels
,
const
int
deformable_group
,
const
int
height_col
,
const
int
width_col
,
scalar_t
*
data_col
)
{
CUDA_KERNEL_LOOP
(
index
,
n
)
{
// index index of output matrix
const
int
w_col
=
index
%
width_col
;
const
int
h_col
=
(
index
/
width_col
)
%
height_col
;
const
int
b_col
=
(
index
/
width_col
/
height_col
)
%
batch_size
;
const
int
c_im
=
(
index
/
width_col
/
height_col
)
/
batch_size
;
const
int
c_col
=
c_im
*
kernel_h
*
kernel_w
;
// compute deformable group index
const
int
deformable_group_index
=
c_im
/
channel_per_deformable_group
;
const
int
h_in
=
h_col
*
stride_h
-
pad_h
;
const
int
w_in
=
w_col
*
stride_w
-
pad_w
;
scalar_t
*
data_col_ptr
=
data_col
+
((
c_col
*
batch_size
+
b_col
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
//const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const
scalar_t
*
data_im_ptr
=
data_im
+
(
b_col
*
num_channels
+
c_im
)
*
height
*
width
;
const
scalar_t
*
data_offset_ptr
=
data_offset
+
(
b_col
*
deformable_group
+
deformable_group_index
)
*
2
*
kernel_h
*
kernel_w
*
height_col
*
width_col
;
for
(
int
i
=
0
;
i
<
kernel_h
;
++
i
)
{
for
(
int
j
=
0
;
j
<
kernel_w
;
++
j
)
{
const
int
data_offset_h_ptr
=
((
2
*
(
i
*
kernel_w
+
j
))
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
int
data_offset_w_ptr
=
((
2
*
(
i
*
kernel_w
+
j
)
+
1
)
*
height_col
+
h_col
)
*
width_col
+
w_col
;
const
scalar_t
offset_h
=
data_offset_ptr
[
data_offset_h_ptr
];
const
scalar_t
offset_w
=
data_offset_ptr
[
data_offset_w_ptr
];
scalar_t
val
=
static_cast
<
scalar_t
>
(
0
);
const
scalar_t
h_im
=
h_in
+
i
*
dilation_h
+
offset_h
;
const
scalar_t
w_im
=
w_in
+
j
*
dilation_w
+
offset_w
;
if
(
h_im
>
-
1
&&
w_im
>
-
1
&&
h_im
<
height
&&
w_im
<
width
)
{
//const scalar_t map_h = i * dilation_h + offset_h;
//const scalar_t map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val
=
deformable_im2col_bilinear
(
data_im_ptr
,
width
,
height
,
width
,
h_im
,
w_im
);
}
*
data_col_ptr
=
val
;
data_col_ptr
+=
batch_size
*
height_col
*
width_col
;
}
}
}
}
void
deformable_im2col
(
const
at
::
Tensor
data_im
,
const
at
::
Tensor
data_offset
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
ksize_h
,
const
int
ksize_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
parallel_imgs
,
const
int
deformable_group
,
at
::
Tensor
data_col
)
{
// num_axes should be smaller than block size
// todo: check parallel_imgs is correctly passed in
int
height_col
=
(
height
+
2
*
pad_h
-
(
dilation_h
*
(
ksize_h
-
1
)
+
1
))
/
stride_h
+
1
;
int
width_col
=
(
width
+
2
*
pad_w
-
(
dilation_w
*
(
ksize_w
-
1
)
+
1
))
/
stride_w
+
1
;
int
num_kernels
=
channels
*
height_col
*
width_col
*
parallel_imgs
;
int
channel_per_deformable_group
=
channels
/
deformable_group
;
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
data_im
.
scalar_type
(),
"deformable_im2col_gpu"
,
([
&
]
{
const
scalar_t
*
data_im_
=
data_im
.
data_ptr
<
scalar_t
>
();
const
scalar_t
*
data_offset_
=
data_offset
.
data_ptr
<
scalar_t
>
();
scalar_t
*
data_col_
=
data_col
.
data_ptr
<
scalar_t
>
();
deformable_im2col_gpu_kernel
<<<
GET_BLOCKS
(
num_kernels
),
CUDA_NUM_THREADS
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
num_kernels
,
data_im_
,
data_offset_
,
height
,
width
,
ksize_h
,
ksize_w
,
pad_h
,
pad_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channel_per_deformable_group
,
parallel_imgs
,
channels
,
deformable_group
,
height_col
,
width_col
,
data_col_
);
}));
cudaError_t
err
=
cudaGetLastError
();
if
(
err
!=
cudaSuccess
)
{
printf
(
"error in deformable_im2col: %s
\n
"
,
cudaGetErrorString
(
err
));
}
}
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
    const int n, const scalar_t *data_col, const scalar_t *data_offset, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h,
    const int pad_w, const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int channel_per_deformable_group, const int batch_size,
    const int deformable_group, const int height_col, const int width_col, scalar_t *grad_im) {
  CUDA_KERNEL_LOOP(index, n) {
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
    // compute the start and end of the output
    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    const scalar_t *data_offset_ptr =
        data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w *
                          height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const scalar_t cur_top_grad = data_col[index];
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;
    for (int dy = -2; dy <= 2; dy++) {
      for (int dx = -2; dx <= 2; dx++) {
        if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1) {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy,
                                                cur_w + dx, height, width);
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
  }
}
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, const int channels,
                       const int height, const int width, const int ksize_h, const int ksize_w,
                       const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                       const int dilation_h, const int dilation_w, const int parallel_imgs,
                       const int deformable_group, at::Tensor grad_im) {
  // todo: make sure parallel_imgs is passed in correctly
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
  int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();

        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0,
                                       at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, ksize_w,
            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, parallel_imgs, deformable_group, height_col, width_col,
            grad_im_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
  }
}
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(
    const int n, const scalar_t *data_col, const scalar_t *data_im, const scalar_t *data_offset,
    const int channels, const int height, const int width, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
    const int batch_size, const int offset_channels, const int deformable_group,
    const int height_col, const int width_col, scalar_t *grad_offset) {
  CUDA_KERNEL_LOOP(index, n) {
    scalar_t val = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output
    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;

    const scalar_t *data_col_ptr = data_col + deformable_group_index *
                                                  channel_per_deformable_group * batch_size *
                                                  width_col * height_col;
    const scalar_t *data_im_ptr =
        data_im + (b * deformable_group + deformable_group_index) *
                      channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *data_offset_ptr =
        data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w *
                          height_col * width_col;

    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
        inv_h = inv_w = -2;
      }
      const scalar_t weight = get_coordinate_weight(inv_h, inv_w, height, width,
                                                    data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos];
      cnt += 1;
    }

    grad_offset[index] = val;
  }
}
void deformable_col2im_coord(const at::Tensor data_col, const at::Tensor data_im,
                             const at::Tensor data_offset, const int channels, const int height,
                             const int width, const int ksize_h, const int ksize_w,
                             const int pad_h, const int pad_w, const int stride_h,
                             const int stride_w, const int dilation_h, const int dilation_w,
                             const int parallel_imgs, const int deformable_group,
                             at::Tensor grad_offset) {
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();

        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0,
                                             at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_im_, data_offset_, channels, height, width, ksize_h,
            ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, parallel_imgs, 2 * ksize_h * ksize_w * deformable_group,
            deformable_group, height_col, width_col, grad_offset_);
      }));
}
template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
                                         const int height, const int width, scalar_t h,
                                         scalar_t w) {
  int h_low = floor(h);
  int w_low = floor(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  scalar_t lh = h - h_low;
  scalar_t lw = w - w_low;
  scalar_t hh = 1 - lh, hw = 1 - lw;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0) v1 = bottom_data[h_low * data_width + w_low];
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1) v2 = bottom_data[h_low * data_width + w_high];
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0) v3 = bottom_data[h_high * data_width + w_low];
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1) v4 = bottom_data[h_high * data_width + w_high];

  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, const int h,
                                             const int w, const int height, const int width) {
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {
    //empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;
  if (h == argmax_h_low && w == argmax_w_low)
    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
  if (h == argmax_h_low && w == argmax_w_high)
    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
  if (h == argmax_h_high && w == argmax_w_low)
    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
  if (h == argmax_h_high && w == argmax_w_high)
    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
  return weight;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
                                               const int height, const int width,
                                               const scalar_t *im_data, const int data_width,
                                               const int bp_dir) {
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {
    //empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;

  if (bp_dir == 0) {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  } else if (bp_dir == 1) {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }

  return weight;
}
template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(
    const int n, const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
    const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h,
    const int pad_w, const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int channel_per_deformable_group, const int batch_size,
    const int num_channels, const int deformable_group, const int height_col,
    const int width_col, scalar_t *data_col) {
  CUDA_KERNEL_LOOP(index, n) {
    // index index of output matrix
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    const int b_col = (index / width_col / height_col) % batch_size;
    const int c_im = (index / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    scalar_t *data_col_ptr =
        data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    const scalar_t *data_offset_ptr =
        data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
                          kernel_w * height_col * width_col;
    const scalar_t *data_mask_ptr =
        data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w *
                        height_col * width_col;

    for (int i = 0; i < kernel_h; ++i) {
      for (int j = 0; j < kernel_w; ++j) {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
        const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
        scalar_t val = static_cast<scalar_t>(0);
        const scalar_t h_im = h_in + i * dilation_h + offset_h;
        const scalar_t w_im = w_in + j * dilation_w + offset_w;
        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
          //const float map_h = i * dilation_h + offset_h;
          //const float map_w = j * dilation_w + offset_w;
          //const int cur_height = height - h_in;
          //const int cur_width = width - w_in;
          //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val * mask;
        data_col_ptr += batch_size * height_col * width_col;
        //data_col_ptr += height_col * width_col;
      }
    }
  }
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(
    const int n, const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
    const int channels, const int height, const int width, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
    const int batch_size, const int deformable_group, const int height_col, const int width_col,
    scalar_t *grad_im) {
  CUDA_KERNEL_LOOP(index, n) {
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
    // compute the start and end of the output
    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    const scalar_t *data_offset_ptr =
        data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w *
                          height_col * width_col;
    const scalar_t *data_mask_ptr =
        data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w *
                        height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
    const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const scalar_t cur_top_grad = data_col[index] * mask;
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;
    for (int dy = -2; dy <= 2; dy++) {
      for (int dx = -2; dx <= 2; dx++) {
        if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1) {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy,
                                                     cur_w + dx, height, width);
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
  }
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(
    const int n, const scalar_t *data_col, const scalar_t *data_im, const scalar_t *data_offset,
    const scalar_t *data_mask, const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
    const int stride_w, const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group, const int batch_size, const int offset_channels,
    const int deformable_group, const int height_col, const int width_col, scalar_t *grad_offset,
    scalar_t *grad_mask) {
  CUDA_KERNEL_LOOP(index, n) {
    scalar_t val = 0, mval = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output
    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;

    const scalar_t *data_col_ptr = data_col + deformable_group_index *
                                                  channel_per_deformable_group * batch_size *
                                                  width_col * height_col;
    const scalar_t *data_im_ptr =
        data_im + (b * deformable_group + deformable_group_index) *
                      channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *data_offset_ptr =
        data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w *
                          height_col * width_col;
    const scalar_t *data_mask_ptr =
        data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w *
                        height_col * width_col;

    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
      const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
        inv_h = inv_w = -2;
      } else {
        mval += data_col_ptr[col_pos] *
                dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width,
                                     inv_h, inv_w);
      }
      const scalar_t weight = dmcn_get_coordinate_weight(
          inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos] * mask;
      cnt += 1;
    }
    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
    grad_offset[index] = val;
    if (offset_c % 2 == 0)
      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w +
                  offset_c / 2) * height_col + h) * width_col + w] = mval;
  }
}
void modulated_deformable_im2col_cuda(
    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group,
    at::Tensor data_col) {
  // num_axes should be smaller than block size
  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * batch_size * height_col * width_col;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();

        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0,
                                                 at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h,
            kenerl_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, batch_size, channels, deformable_group, height_col,
            width_col, data_col_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }
}
void modulated_deformable_col2im_cuda(
    const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group, at::Tensor grad_im) {
  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();

        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0,
                                                 at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, batch_size, deformable_group, height_col, width_col,
            grad_im_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}
void modulated_deformable_col2im_coord_cuda(
    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
    const at::Tensor data_mask, const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col, const int kernel_h,
    const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group,
    at::Tensor grad_offset, at::Tensor grad_mask) {
  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
        scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();

        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS,
                                                       0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im,
            width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
            dilation_w, channel_per_deformable_group, batch_size,
            2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
            grad_offset_, grad_mask_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
  }
}
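The kernels above all rely on two small pieces of arithmetic: the column-buffer spatial size (height_col/width_col, the standard convolution output-size formula) and corner-weighted bilinear sampling at a fractional, offset-shifted location. The following is a minimal illustrative Python sketch of those two formulas only, written for checking the indexing by hand; it is not part of the repository.

# Illustrative sketch (not part of the repo): mirrors the formulas used in the CUDA kernels above.
import math

def col_spatial_size(size, pad, dilation, ksize, stride):
    # height_col / width_col as computed in deformable_im2col and friends
    return (size + 2 * pad - (dilation * (ksize - 1) + 1)) // stride + 1

def bilinear_sample(img, h, w):
    # same corner weighting as deformable_im2col_bilinear / dmcn_im2col_bilinear,
    # with out-of-range corners contributing zero
    rows, cols = len(img), len(img[0])
    h_low, w_low = math.floor(h), math.floor(w)
    h_high, w_high = h_low + 1, w_low + 1
    lh, lw = h - h_low, w - w_low
    hh, hw = 1 - lh, 1 - lw
    def at(y, x):
        return img[y][x] if 0 <= y < rows and 0 <= x < cols else 0.0
    return (hh * hw * at(h_low, w_low) + hh * lw * at(h_low, w_high) +
            lh * hw * at(h_high, w_low) + lh * lw * at(h_high, w_high))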
sr_model/Basicsr/basicsr/ops/dcn/src/deform_conv_ext.cpp
0 → 100644
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c

#include <torch/extension.h>
#include <ATen/DeviceGuard.h>

#include <cmath>
#include <vector>

#define WITH_CUDA  // always use cuda
#ifdef WITH_CUDA
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, at::Tensor offset,
                             at::Tensor output, at::Tensor columns, at::Tensor ones, int kW,
                             int kH, int dW, int dH, int padW, int padH, int dilationW,
                             int dilationH, int group, int deformable_group, int im2col_step);

int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
                                    at::Tensor gradInput, at::Tensor gradOffset, at::Tensor weight,
                                    at::Tensor columns, int kW, int kH, int dW, int dH, int padW,
                                    int padH, int dilationW, int dilationH, int group,
                                    int deformable_group, int im2col_step);

int deform_conv_backward_parameters_cuda(at::Tensor input, at::Tensor offset,
                                         at::Tensor gradOutput, at::Tensor gradWeight,
                                         // at::Tensor gradBias,
                                         at::Tensor columns, at::Tensor ones, int kW, int kH,
                                         int dW, int dH, int padW, int padH, int dilationW,
                                         int dilationH, int group, int deformable_group,
                                         float scale, int im2col_step);

void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
                                        at::Tensor ones, at::Tensor offset, at::Tensor mask,
                                        at::Tensor output, at::Tensor columns, int kernel_h,
                                        int kernel_w, const int stride_h, const int stride_w,
                                        const int pad_h, const int pad_w, const int dilation_h,
                                        const int dilation_w, const int group,
                                        const int deformable_group, const bool with_bias);

void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight, at::Tensor bias,
                                         at::Tensor ones, at::Tensor offset, at::Tensor mask,
                                         at::Tensor columns, at::Tensor grad_input,
                                         at::Tensor grad_weight, at::Tensor grad_bias,
                                         at::Tensor grad_offset, at::Tensor grad_mask,
                                         at::Tensor grad_output, int kernel_h, int kernel_w,
                                         int stride_h, int stride_w, int pad_h, int pad_w,
                                         int dilation_h, int dilation_w, int group,
                                         int deformable_group, const bool with_bias);
#endif

int deform_conv_forward(at::Tensor input, at::Tensor weight, at::Tensor offset, at::Tensor output,
                        at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
                        int padW, int padH, int dilationW, int dilationH, int group,
                        int deformable_group, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return deform_conv_forward_cuda(input, weight, offset, output, columns, ones, kW, kH, dW, dH,
                                    padW, padH, dilationW, dilationH, group, deformable_group,
                                    im2col_step);
#else
    AT_ERROR("deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("deform conv is not implemented on CPU");
}

int deform_conv_backward_input(at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
                               at::Tensor gradInput, at::Tensor gradOffset, at::Tensor weight,
                               at::Tensor columns, int kW, int kH, int dW, int dH, int padW,
                               int padH, int dilationW, int dilationH, int group,
                               int deformable_group, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return deform_conv_backward_input_cuda(input, offset, gradOutput, gradInput, gradOffset,
                                           weight, columns, kW, kH, dW, dH, padW, padH, dilationW,
                                           dilationH, group, deformable_group, im2col_step);
#else
    AT_ERROR("deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("deform conv is not implemented on CPU");
}

int deform_conv_backward_parameters(at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
                                    at::Tensor gradWeight,
                                    // at::Tensor gradBias,
                                    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW,
                                    int dH, int padW, int padH, int dilationW, int dilationH,
                                    int group, int deformable_group, float scale,
                                    int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return deform_conv_backward_parameters_cuda(input, offset, gradOutput, gradWeight, columns,
                                                ones, kW, kH, dW, dH, padW, padH, dilationW,
                                                dilationH, group, deformable_group, scale,
                                                im2col_step);
#else
    AT_ERROR("deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("deform conv is not implemented on CPU");
}

void modulated_deform_conv_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
                                   at::Tensor ones, at::Tensor offset, at::Tensor mask,
                                   at::Tensor output, at::Tensor columns, int kernel_h,
                                   int kernel_w, const int stride_h, const int stride_w,
                                   const int pad_h, const int pad_w, const int dilation_h,
                                   const int dilation_w, const int group,
                                   const int deformable_group, const bool with_bias) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return modulated_deform_conv_cuda_forward(input, weight, bias, ones, offset, mask, output,
                                              columns, kernel_h, kernel_w, stride_h, stride_w,
                                              pad_h, pad_w, dilation_h, dilation_w, group,
                                              deformable_group, with_bias);
#else
    AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("modulated deform conv is not implemented on CPU");
}

void modulated_deform_conv_backward(at::Tensor input, at::Tensor weight, at::Tensor bias,
                                    at::Tensor ones, at::Tensor offset, at::Tensor mask,
                                    at::Tensor columns, at::Tensor grad_input,
                                    at::Tensor grad_weight, at::Tensor grad_bias,
                                    at::Tensor grad_offset, at::Tensor grad_mask,
                                    at::Tensor grad_output, int kernel_h, int kernel_w,
                                    int stride_h, int stride_w, int pad_h, int pad_w,
                                    int dilation_h, int dilation_w, int group,
                                    int deformable_group, const bool with_bias) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return modulated_deform_conv_cuda_backward(input, weight, bias, ones, offset, mask, columns,
                                               grad_input, grad_weight, grad_bias, grad_offset,
                                               grad_mask, grad_output, kernel_h, kernel_w,
                                               stride_h, stride_w, pad_h, pad_w, dilation_h,
                                               dilation_w, group, deformable_group, with_bias);
#else
    AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("modulated deform conv is not implemented on CPU");
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("deform_conv_forward", &deform_conv_forward, "deform forward");
  m.def("deform_conv_backward_input", &deform_conv_backward_input, "deform_conv_backward_input");
  m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters,
        "deform_conv_backward_parameters");
  m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward,
        "modulated deform conv forward");
  m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward,
        "modulated deform conv backward");
}
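The pybind module above only exposes the forward/backward entry points; it still has to be compiled into an importable Python module. One way to obtain it is a JIT build, shown below as a hedged sketch: the source file names are assumed from this commit's layout, and the same torch.utils.cpp_extension.load pattern is what the Python wrappers in this commit use when BASICSR_JIT is set to 'True'.

# Hypothetical JIT build of the deform-conv extension (source paths are assumptions):
import os
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)
deform_conv_ext = load(
    'deform_conv',
    sources=[
        os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
        os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
        os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
    ],
)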
sr_model/Basicsr/basicsr/ops/fused_act/__init__.py
0 → 100644
from .fused_act import FusedLeakyReLU, fused_leaky_relu

__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
sr_model/Basicsr/basicsr/ops/fused_act/fused_act.py
0 → 100644
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501

import os
import torch
from torch import nn
from torch.autograd import Function

BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
    from torch.utils.cpp_extension import load
    module_path = os.path.dirname(__file__)
    fused_act_ext = load(
        'fused',
        sources=[
            os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
            os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
        ],
    )
else:
    try:
        from . import fused_act_ext
    except ImportError:
        pass
        # avoid annoying print output
        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
        #       '1. compile with BASICSR_EXT=True. or\n '
        #       '2. set BASICSR_JIT=True during running')


class FusedLeakyReLUFunctionBackward(Function):

    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)

        dim = [0]
        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        grad_bias = grad_input.sum(dim).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1,
                                                    ctx.negative_slope, ctx.scale)

        return gradgrad_out, None, None, None


class FusedLeakyReLUFunction(Function):

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)
        out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):

    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
    return FusedLeakyReLUFunction.apply(input, bias.type_as(input), negative_slope, scale)
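A short usage sketch of the module defined above; shapes are arbitrary examples, and it assumes the CUDA build of fused_act_ext is importable, since this file provides no CPU fallback.

# Usage sketch (requires a CUDA build of fused_act_ext; tensor shapes are arbitrary examples).
import torch

act = FusedLeakyReLU(channel=64).cuda()           # learnable per-channel bias, slope 0.2, scale sqrt(2)
x = torch.randn(8, 64, 32, 32, device='cuda')
y = act(x)                                        # bias add + leaky ReLU + scale in one fused kernel
y.sum().backward()                                # gradients flow through the custom autograd Function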
sr_model/Basicsr/basicsr/ops/fused_act/src/fused_bias_act.cpp
0 → 100644
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
#include <torch/extension.h>


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias,
                                const torch::Tensor& refer, int act, int grad, float alpha,
                                float scale);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias,
                             const torch::Tensor& refer, int act, int grad, float alpha,
                             float scale) {
  CHECK_CUDA(input);
  CHECK_CUDA(bias);

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
sr_model/Basicsr/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
0 → 100644
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>


template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x,
                                             const scalar_t* p_b, const scalar_t* p_ref, int act,
                                             int grad, scalar_t alpha, scalar_t scale, int loop_x,
                                             int size_x, int step_b, int size_b, int use_bias,
                                             int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    switch (act * 10 + grad) {
      default:
      case 10: y = x; break;
      case 11: y = x; break;
      case 12: y = 0.0; break;

      case 30: y = (x > 0.0) ? x : x * alpha; break;
      case 31: y = (ref > 0.0) ? x : x * alpha; break;
      case 32: y = 0.0; break;
    }

    out[xi] = y * scale;
  }
}


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias,
                                const torch::Tensor& refer, int act, int grad, float alpha,
                                float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
    fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
        y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), b.data_ptr<scalar_t>(),
        ref.data_ptr<scalar_t>(), act, grad, alpha, scale, loop_x, size_x, step_b, size_b,
        use_bias, use_ref);
  });

  return y;
}
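The switch in the kernel above dispatches on act * 10 + grad: codes 10/11 pass the value through, code 30 is the leaky ReLU forward, code 31 is its backward gated by the saved forward output, and codes 12/32 return zero for the second-order terms. A compact restatement, for reading convenience only (the Python below is an illustration, not code from the repository):

# Illustrative restatement of the (act, grad) switch in fused_bias_act_kernel.
def fused_act_value(x, ref, act, grad, alpha):
    code = act * 10 + grad
    if code == 30:              # leaky ReLU forward
        return x if x > 0.0 else x * alpha
    if code == 31:              # leaky ReLU backward, gated by the saved forward output `ref`
        return x if ref > 0.0 else x * alpha
    if code in (12, 32):        # second-derivative codes are zero for these activations
        return 0.0
    return x                    # codes 10 and 11: pass-through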
sr_model/Basicsr/basicsr/ops/upfirdn2d/__init__.py
0 → 100644
from .upfirdn2d import upfirdn2d

__all__ = ['upfirdn2d']
sr_model/Basicsr/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
0 → 100644
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
#include <torch/extension.h>


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, int up_x,
                           int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
                           int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, int up_x,
                        int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
                        int pad_y1) {
  CHECK_CUDA(input);
  CHECK_CUDA(kernel);

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
sr_model/Basicsr/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
0 → 100644
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel, const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h || major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim; loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base; loop_x < p.loop_x && out_x < p.out_w;
         loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h,
          int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();

      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input, const torch::Tensor &kernel, int up_x,
                           int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0,
                           int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;

  auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 &&
      p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 &&
      p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 &&
      p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 &&
      p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 &&
      p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 &&
      p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
      case 1:
        upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p);
        break;
      case 2:
        upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p);
        break;
      case 3:
        upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p);
        break;
      case 4:
        upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p);
        break;
      case 5:
        upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p);
        break;
      case 6:
        upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p);
        break;
      default:
        upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
sr_model/Basicsr/basicsr/ops/upfirdn2d/upfirdn2d.py
0 → 100644
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
import
os
import
torch
from
torch.autograd
import
Function
from
torch.nn
import
functional
as
F
BASICSR_JIT
=
os
.
getenv
(
'BASICSR_JIT'
)
if
BASICSR_JIT
==
'True'
:
from
torch.utils.cpp_extension
import
load
module_path
=
os
.
path
.
dirname
(
__file__
)
upfirdn2d_ext
=
load
(
'upfirdn2d'
,
sources
=
[
os
.
path
.
join
(
module_path
,
'src'
,
'upfirdn2d.cpp'
),
os
.
path
.
join
(
module_path
,
'src'
,
'upfirdn2d_kernel.cu'
),
],
)
else
:
try
:
from
.
import
upfirdn2d_ext
except
ImportError
:
pass
# avoid annoying print output
# print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
# '1. compile with BASICSR_EXT=True. or\n '
# '2. set BASICSR_JIT=True during running')
class
UpFirDn2dBackward
(
Function
):
@
staticmethod
def
forward
(
ctx
,
grad_output
,
kernel
,
grad_kernel
,
up
,
down
,
pad
,
g_pad
,
in_size
,
out_size
):
up_x
,
up_y
=
up
down_x
,
down_y
=
down
g_pad_x0
,
g_pad_x1
,
g_pad_y0
,
g_pad_y1
=
g_pad
grad_output
=
grad_output
.
reshape
(
-
1
,
out_size
[
0
],
out_size
[
1
],
1
)
grad_input
=
upfirdn2d_ext
.
upfirdn2d
(
grad_output
,
grad_kernel
,
down_x
,
down_y
,
up_x
,
up_y
,
g_pad_x0
,
g_pad_x1
,
g_pad_y0
,
g_pad_y1
,
)
grad_input
=
grad_input
.
view
(
in_size
[
0
],
in_size
[
1
],
in_size
[
2
],
in_size
[
3
])
ctx
.
save_for_backward
(
kernel
)
pad_x0
,
pad_x1
,
pad_y0
,
pad_y1
=
pad
ctx
.
up_x
=
up_x
ctx
.
up_y
=
up_y
ctx
.
down_x
=
down_x
ctx
.
down_y
=
down_y
ctx
.
pad_x0
=
pad_x0
ctx
.
pad_x1
=
pad_x1
ctx
.
pad_y0
=
pad_y0
ctx
.
pad_y1
=
pad_y1
ctx
.
in_size
=
in_size
ctx
.
out_size
=
out_size
return
grad_input
@
staticmethod
def
backward
(
ctx
,
gradgrad_input
):
kernel
,
=
ctx
.
saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_ext.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
        #                                  ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        _, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if input.device.type == 'cpu':
        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
    else:
        out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))

    return out


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
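
A brief usage sketch for the upfirdn2d wrapper above (not part of the file). The kernel is an arbitrary normalized 3x3 averaging filter and the sizes are illustrative; on a CPU tensor the call falls back to upfirdn2d_native, so no CUDA extension is required:

import torch

x = torch.randn(1, 3, 16, 16)        # (batch, channel, height, width)
k = torch.full((3, 3), 1.0 / 9)      # simple 2D FIR kernel that sums to 1
y = upfirdn2d(x, k, up=2, down=1, pad=(1, 1))
print(y.shape)                       # torch.Size([1, 3, 32, 32]) for these settings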
sr_model/Basicsr/basicsr/test.py
0 → 100644
import logging
import torch
from os import path as osp

from basicsr.data import build_dataloader, build_dataset
from basicsr.models import build_model
from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
from basicsr.utils.options import dict2str, parse_options


def test_pipeline(root_path):
    # parse options, set distributed setting, set random seed
    opt, _ = parse_options(root_path, is_train=False)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # mkdir and initialize loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # create test dataset and dataloader
    test_loaders = []
    for _, dataset_opt in sorted(opt['datasets'].items()):
        test_set = build_dataset(dataset_opt)
        test_loader = build_dataloader(
            test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
        logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
        test_loaders.append(test_loader)

    # create model
    model = build_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info(f'Testing {test_set_name}...')
        model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])


if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    test_pipeline(root_path)
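
A hedged invocation sketch for the pipeline above. It assumes BasicSR's parse_options reads the configuration path from an -opt command-line argument; the YAML path and working directory are placeholders:

# Minimal sketch, not part of the repository: calling test_pipeline directly
# by injecting a hypothetical argv, since parse_options reads sys.argv itself.
import sys
from os import path as osp

sys.argv = ['test.py', '-opt', 'options/test/example_test.yml']  # placeholder option file
test_pipeline(osp.abspath('.'))  # '.' stands for the repository root here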
sr_model/Basicsr/basicsr/train.py
0 → 100644
import datetime
import logging
import math
import time
import torch
from os import path as osp

from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
                           init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
from basicsr.utils.options import copy_opt_file, dict2str, parse_options


def init_tb_loggers(opt):
    # initialize wandb logger before tensorboard logger to allow proper sync
    if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None) and (
            'debug' not in opt['name']):
        assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
        init_wandb_logger(opt)
    tb_logger = None
    if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
    return tb_logger


def create_train_val_dataloader(opt, logger):
    # create train and val dataloaders
    train_loader, val_loaders = None, []
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
            train_set = build_dataset(dataset_opt)
            train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
            train_loader = build_dataloader(
                train_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],
                sampler=train_sampler,
                seed=opt['manual_seed'])

            num_iter_per_epoch = math.ceil(
                len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
            total_iters = int(opt['train']['total_iter'])
            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
            logger.info('Training statistics:'
                        f'\n\tNumber of train images: {len(train_set)}'
                        f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
                        f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
                        f'\n\tWorld size (gpu number): {opt["world_size"]}'
                        f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
                        f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
        elif phase.split('_')[0] == 'val':
            val_set = build_dataset(dataset_opt)
            val_loader = build_dataloader(
                val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
            logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
            val_loaders.append(val_loader)
        else:
            raise ValueError(f'Dataset phase {phase} is not recognized.')

    return train_loader, train_sampler, val_loaders, total_epochs, total_iters
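
# Worked example of the statistics computed above (numbers are illustrative only):
# with 1000 train images, dataset_enlarge_ratio = 10, batch_size_per_gpu = 4 and
# world_size = 2, num_iter_per_epoch = ceil(1000 * 10 / (4 * 2)) = 1250, and a
# total_iter of 250000 then gives total_epochs = ceil(250000 / 1250) = 200.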

def load_resume_state(opt):
    resume_state_path = None
    if opt['auto_resume']:
        state_path = osp.join('experiments', opt['name'], 'training_states')
        if osp.isdir(state_path):
            states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
            if len(states) != 0:
                states = [float(v.split('.state')[0]) for v in states]
                resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
                opt['path']['resume_state'] = resume_state_path
    else:
        if opt['path'].get('resume_state'):
            resume_state_path = opt['path']['resume_state']

    if resume_state_path is None:
        resume_state = None
    else:
        device_id = torch.cuda.current_device()
        resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
        check_resume(opt, resume_state['iter'])
    return resume_state
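
# Standalone illustration of the auto-resume selection above (file names are
# hypothetical): if experiments/<name>/training_states/ holds 2500.state,
# 5000.state and 10000.state, the comprehension yields [2500.0, 5000.0, 10000.0]
# and f'{max(states):.0f}.state' resolves to '10000.state'.
states = [2500.0, 5000.0, 10000.0]
print(f'{max(states):.0f}.state')  # -> 10000.state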

def train_pipeline(root_path):
    # parse options, set distributed setting, set random seed
    opt, args = parse_options(root_path, is_train=True)
    opt['root_path'] = root_path

    # torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    # load resume states if necessary
    resume_state = load_resume_state(opt)
    # mkdir for experiments and logger
    if resume_state is None:
        make_exp_dirs(opt)
        if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
            mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))

    # copy the yml file to the experiment root
    copy_opt_file(args.opt, opt['path']['experiments_root'])

    # WARNING: should not use get_root_logger in the above codes, including the called functions
    # Otherwise the logger will not be properly initialized
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))
    # initialize wandb and tb loggers
    tb_logger = init_tb_loggers(opt)

    # create train and validation dataloaders
    result = create_train_val_dataloader(opt, logger)
    train_loader, train_sampler, val_loaders, total_epochs, total_iters = result

    # create model
    model = build_model(opt)
    if resume_state:  # resume training
        model.resume_training(resume_state)  # handle optimizers and schedulers
        logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
        start_epoch = resume_state['epoch']
        current_iter = resume_state['iter']
    else:
        start_epoch = 0
        current_iter = 0

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    # dataloader prefetcher
    prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
    if prefetch_mode is None or prefetch_mode == 'cpu':
        prefetcher = CPUPrefetcher(train_loader)
    elif prefetch_mode == 'cuda':
        prefetcher = CUDAPrefetcher(train_loader, opt)
        logger.info(f'Use {prefetch_mode} prefetch dataloader')
        if opt['datasets']['train'].get('pin_memory') is not True:
            raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
    else:
        raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")

    if opt.get('val') is not None and (opt['path'].get('pretrain_network_g', None) is not None):
        if len(val_loaders) > 1:
            logger.warning('Multiple validation datasets are *only* supported by SRModel.')
        for val_loader in val_loaders:
            model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])

    # training
    logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
    data_timer, iter_timer = AvgTimer(), AvgTimer()
    start_time = time.time()

    for epoch in range(start_epoch, total_epochs + 1):
        train_sampler.set_epoch(epoch)
        prefetcher.reset()
        train_data = prefetcher.next()

        while train_data is not None:
            data_timer.record()

            current_iter += 1
            if current_iter > total_iters:
                break
            # update learning rate
            model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_iter)
            iter_timer.record()
            if current_iter == 1:
                # reset start time in msg_logger for a more accurate eta_time;
                # this does not work in resume mode
                msg_logger.reset_start_time()
            # log
            if current_iter % opt['logger']['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': model.get_current_learning_rate()})
                log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

                if 'print_module_para' in opt['logger']:
                    logger = get_root_logger()

                    def print_stat(a):
                        return (f"shape={a.shape}, min={a.min():.2f}, median={a.median():.2f}, max={a.max():.2f}, "
                                f"var={a.var():.2f}, {a.flatten()[0]}")

                    for module_name in opt['logger']["print_module_para"]:
                        if hasattr(model.get_bare_model(model.net_g), module_name):
                            # if 'module_name' in
                            module = getattr(model.get_bare_model(model.net_g), module_name)
                            if len(list(module.named_parameters())) > 0 and len(list(module.named_parameters())[0]) > 0:
                                p = list(module.named_parameters())[0][1]
                                # print_stat(p)
                                # logger = get_root_logger()
                                logger.info(f"parameters of {module_name} " + print_stat(p) +
                                            f" require grad {p.requires_grad}")

            # save models and training states
            if current_iter % opt['logger']['save_checkpoint_freq'] == 0 or current_iter == 1:
                logger.info('Saving models and training states.')
                model.save(epoch, current_iter)
                if hasattr(model, "save_during_training"):
                    model.save_during_training(current_iter)

            # validation
            if opt.get('val') is not None and ((current_iter % opt['val']['val_freq'] == 0) or current_iter == 1):
                if len(val_loaders) > 1:
                    logger.warning('Multiple validation datasets are *only* supported by SRModel.')
                for val_loader in val_loaders:
                    model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])

            data_timer.start()
            iter_timer.start()
            train_data = prefetcher.next()
        # end of iter

    # end of epoch

    consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    logger.info(f'End of training. Time consumed: {consumed_time}')
    logger.info('Save the latest model.')
    model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
    if opt.get('val') is not None:
        for val_loader in val_loaders:
            model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
    if tb_logger:
        tb_logger.close()


if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
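
A hedged illustration of the options dictionary the pipeline above consumes. Only keys that actually appear in this file are shown; all values are placeholders, and a real run supplies many more entries through the YAML file parsed by parse_options:

opt_fragment = {
    'name': 'example_experiment',              # placeholder experiment name
    'auto_resume': True,
    'datasets': {
        'train': {
            'prefetch_mode': 'cuda',           # or 'cpu' / None for CPUPrefetcher
            'pin_memory': True,                # required when prefetch_mode == 'cuda'
            'batch_size_per_gpu': 4,
            'dataset_enlarge_ratio': 1,
        },
    },
    'train': {'total_iter': 250000, 'warmup_iter': -1},
    'logger': {'print_freq': 100, 'save_checkpoint_freq': 5000, 'use_tb_logger': True},
    'val': {'val_freq': 5000, 'save_img': False},
}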
sr_model/Basicsr/basicsr/utils/__init__.py
0 → 100644
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, folder_to_concat_folder, folder_to_video
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt

__all__ = [
    # color_util.py
    'bgr2ycbcr',
    'rgb2ycbcr',
    'rgb2ycbcr_pt',
    'ycbcr2bgr',
    'ycbcr2rgb',
    # file_client.py
    'FileClient',
    # img_util.py
    'img2tensor',
    'tensor2img',
    'imfrombytes',
    'imwrite',
    'crop_border',
    'folder_to_concat_folder',
    'folder_to_video',
    # logger.py
    'MessageLogger',
    'AvgTimer',
    'init_tb_logger',
    'init_wandb_logger',
    'get_root_logger',
    'get_env_info',
    # misc.py
    'set_random_seed',
    'get_time_str',
    'mkdir_and_rename',
    'make_exp_dirs',
    'scandir',
    'check_resume',
    'sizeof_fmt',
    # diffjpeg
    'DiffJPEG',
    # img_process_util
    'USMSharp',
    'usm_sharp'
]
sr_model/Basicsr/basicsr/utils/color_util.py
0 → 100644
import numpy as np
import torch


def rgb2ycbcr(img, y_only=False):
    """Convert a RGB image to YCbCr image.

    This function produces the same results as Matlab's `rgb2ycbcr` function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        out_img = np.matmul(
            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
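
# Illustrative usage (not part of the original file): float32 input in [0, 1]
# comes back as float32 in [0, 1]; uint8 input in [0, 255] comes back as uint8.
# >>> img = np.random.rand(4, 4, 3).astype(np.float32)
# >>> rgb2ycbcr(img, y_only=True).shape   # (4, 4), luma roughly in [16/255, 235/255]
# >>> rgb2ycbcr(img).shape                # (4, 4, 3), Y, Cb, Cr channels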

def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        out_img = np.matmul(
            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB image.

    This function produces the same results as Matlab's ycbcr2rgb function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(
        img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
              [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR image.

    The bgr version of ycbcr2rgb.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted BGR image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(
        img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
              [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    conversion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
            [0, 1].
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    It converts the image to desired type and range. If `dst_type` is np.uint8,
    images will be converted to np.uint8 type with range [0, 255]. If
    `dst_type` is np.float32, it converts the image to np.float32 type with
    range [0, 1].
    It is mainly used for post-processing images in colorspace conversion
    functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)


def rgb2ycbcr_pt(img, y_only=False):
    """Convert RGB images to YCbCr images (PyTorch version).

    It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
    """
    if y_only:
        weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
    else:
        weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
        bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias

    out_img = out_img / 255.
    return out_img
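
A short hedged sketch of the PyTorch variant above; the batch and spatial sizes are arbitrary:

import torch
from basicsr.utils.color_util import rgb2ycbcr_pt

imgs = torch.rand(2, 3, 8, 8)            # (n, 3, h, w), RGB, range [0, 1]
y = rgb2ycbcr_pt(imgs, y_only=True)      # (2, 1, 8, 8), luma only, range [0, 1]
ycbcr = rgb2ycbcr_pt(imgs)               # (2, 3, 8, 8), Y, Cb, Cr channels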
sr_model/Basicsr/basicsr/utils/diffjpeg.py
0 → 100644
"""
Modified from https://github.com/mlomnitz/DiffJPEG
For images not divisible by 8
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
"""
import
itertools
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
torch.nn
import
functional
as
F
# ------------------------ utils ------------------------#
y_table
=
np
.
array
(
[[
16
,
11
,
10
,
16
,
24
,
40
,
51
,
61
],
[
12
,
12
,
14
,
19
,
26
,
58
,
60
,
55
],
[
14
,
13
,
16
,
24
,
40
,
57
,
69
,
56
],
[
14
,
17
,
22
,
29
,
51
,
87
,
80
,
62
],
[
18
,
22
,
37
,
56
,
68
,
109
,
103
,
77
],
[
24
,
35
,
55
,
64
,
81
,
104
,
113
,
92
],
[
49
,
64
,
78
,
87
,
103
,
121
,
120
,
101
],
[
72
,
92
,
95
,
98
,
112
,
100
,
103
,
99
]],
dtype
=
np
.
float32
).
T
y_table
=
nn
.
Parameter
(
torch
.
from_numpy
(
y_table
))
c_table
=
np
.
empty
((
8
,
8
),
dtype
=
np
.
float32
)
c_table
.
fill
(
99
)
c_table
[:
4
,
:
4
]
=
np
.
array
([[
17
,
18
,
24
,
47
],
[
18
,
21
,
26
,
66
],
[
24
,
26
,
56
,
99
],
[
47
,
66
,
99
,
99
]]).
T
c_table
=
nn
.
Parameter
(
torch
.
from_numpy
(
c_table
))
def
diff_round
(
x
):
""" Differentiable rounding function
"""
return
torch
.
round
(
x
)
+
(
x
-
torch
.
round
(
x
))
**
3
def
quality_to_factor
(
quality
):
""" Calculate factor corresponding to quality
Args:
quality(float): Quality for jpeg compression.
Returns:
float: Compression factor.
"""
if
quality
<
50
:
quality
=
5000.
/
quality
else
:
quality
=
200.
-
quality
*
2
return
quality
/
100.
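
# Worked examples of the two helpers above (values follow directly from the formulas):
#   diff_round(torch.tensor(2.3)) -> 2.0 + (2.3 - 2.0)**3 = 2.027; unlike plain
#       torch.round it keeps a nonzero gradient of 3 * (x - round(x))**2.
#   quality_to_factor(20) -> (5000 / 20) / 100 = 2.5
#   quality_to_factor(90) -> (200 - 90 * 2) / 100 = 0.2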

# ------------------------ compression ------------------------#
class RGB2YCbCrJpeg(nn.Module):
    """ Converts RGB image to YCbCr
    """

    def __init__(self):
        super(RGB2YCbCrJpeg, self).__init__()
        matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
                          dtype=np.float32).T
        self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        """
        Args:
            image(Tensor): batch x 3 x height x width

        Returns:
            Tensor: batch x height x width x 3
        """
        image = image.permute(0, 2, 3, 1)
        result = torch.tensordot(image, self.matrix, dims=1) + self.shift
        return result.view(image.shape)


class ChromaSubsampling(nn.Module):
    """ Chroma subsampling on CbCr channels
    """

    def __init__(self):
        super(ChromaSubsampling, self).__init__()

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width x 3

        Returns:
            y(tensor): batch x height x width
            cb(tensor): batch x height/2 x width/2
            cr(tensor): batch x height/2 x width/2
        """
        image_2 = image.permute(0, 3, 1, 2).clone()
        cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
        cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
        cb = cb.permute(0, 2, 3, 1)
        cr = cr.permute(0, 2, 3, 1)
        return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)


class BlockSplitting(nn.Module):
    """ Splitting image into patches
    """

    def __init__(self):
        super(BlockSplitting, self).__init__()
        self.k = 8

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x h*w/64 x h x w
        """
        height, _ = image.shape[1:3]
        batch_size = image.shape[0]
        image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
        return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)


class DCT8x8(nn.Module):
    """ Discrete Cosine Transformation
    """

    def __init__(self):
        super(DCT8x8, self).__init__()
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
        self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        image = image - 128
        result = self.scale * torch.tensordot(image, self.tensor, dims=2)
        result.view(image.shape)
        return result


class YQuantize(nn.Module):
    """ JPEG Quantization for Y channel

    Args:
        rounding(function): rounding function to use
    """

    def __init__(self, rounding):
        super(YQuantize, self).__init__()
        self.rounding = rounding
        self.y_table = y_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            image = image.float() / (self.y_table * factor)
        else:
            b = factor.size(0)
            table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            image = image.float() / table
        image = self.rounding(image)
        return image


class CQuantize(nn.Module):
    """ JPEG Quantization for CbCr channels

    Args:
        rounding(function): rounding function to use
    """

    def __init__(self, rounding):
        super(CQuantize, self).__init__()
        self.rounding = rounding
        self.c_table = c_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            image = image.float() / (self.c_table * factor)
        else:
            b = factor.size(0)
            table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            image = image.float() / table
        image = self.rounding(image)
        return image


class CompressJpeg(nn.Module):
    """Full JPEG compression algorithm

    Args:
        rounding(function): rounding function to use
    """

    def __init__(self, rounding=torch.round):
        super(CompressJpeg, self).__init__()
        self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
        self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
        self.c_quantize = CQuantize(rounding=rounding)
        self.y_quantize = YQuantize(rounding=rounding)

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x 3 x height x width

        Returns:
            dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8.
        """
        y, cb, cr = self.l1(image * 255)
        components = {'y': y, 'cb': cb, 'cr': cr}
        for k in components.keys():
            comp = self.l2(components[k])
            if k in ('cb', 'cr'):
                comp = self.c_quantize(comp, factor=factor)
            else:
                comp = self.y_quantize(comp, factor=factor)
            components[k] = comp

        return components['y'], components['cb'], components['cr']
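
# Shape sketch for the forward pass above (assuming an input divisible by 16):
# an input batch of (b, 3, 64, 64) yields y blocks of shape (b, 64*64/64, 8, 8)
# = (b, 64, 8, 8), while cb and cr are first subsampled to 32x32 and therefore
# come back as (b, 16, 8, 8) after block splitting, DCT and quantization.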

# ------------------------ decompression ------------------------#
class YDequantize(nn.Module):
    """Dequantize Y channel
    """

    def __init__(self):
        super(YDequantize, self).__init__()
        self.y_table = y_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            out = image * (self.y_table * factor)
        else:
            b = factor.size(0)
            table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            out = image * table
        return out


class CDequantize(nn.Module):
    """Dequantize CbCr channel
    """

    def __init__(self):
        super(CDequantize, self).__init__()
        self.c_table = c_table

    def forward(self, image, factor=1):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        if isinstance(factor, (int, float)):
            out = image * (self.c_table * factor)
        else:
            b = factor.size(0)
            table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
            out = image * table
        return out


class iDCT8x8(nn.Module):
    """Inverse discrete Cosine Transformation
    """

    def __init__(self):
        super(iDCT8x8, self).__init__()
        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width

        Returns:
            Tensor: batch x height x width
        """
        image = image * self.alpha
        result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
        result.view(image.shape)
        return result


class BlockMerging(nn.Module):
    """Merge patches into image
    """

    def __init__(self):
        super(BlockMerging, self).__init__()

    def forward(self, patches, height, width):
        """
        Args:
            patches(tensor) batch x height*width/64, height x width
            height(int)
            width(int)

        Returns:
            Tensor: batch x height x width
        """
        k = 8
        batch_size = patches.shape[0]
        image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
        return image_transposed.contiguous().view(batch_size, height, width)


class ChromaUpsampling(nn.Module):
    """Upsample chroma layers
    """

    def __init__(self):
        super(ChromaUpsampling, self).__init__()

    def forward(self, y, cb, cr):
        """
        Args:
            y(tensor): y channel image
            cb(tensor): cb channel
            cr(tensor): cr channel

        Returns:
            Tensor: batch x height x width x 3
        """

        def repeat(x, k=2):
            height, width = x.shape[1:3]
            x = x.unsqueeze(-1)
            x = x.repeat(1, 1, k, k)
            x = x.view(-1, height * k, width * k)
            return x

        cb = repeat(cb)
        cr = repeat(cr)
        return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)


class YCbCr2RGBJpeg(nn.Module):
    """Converts YCbCr image to RGB JPEG
    """

    def __init__(self):
        super(YCbCr2RGBJpeg, self).__init__()

        matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
        self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        """
        Args:
            image(tensor): batch x height x width x 3

        Returns:
            Tensor: batch x 3 x height x width
        """
        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
        return result.view(image.shape).permute(0, 3, 1, 2)


class DeCompressJpeg(nn.Module):
    """Full JPEG decompression algorithm

    Args:
        rounding(function): rounding function to use
    """

    def __init__(self, rounding=torch.round):
        super(DeCompressJpeg, self).__init__()
        self.c_dequantize = CDequantize()
        self.y_dequantize = YDequantize()
        self.idct = iDCT8x8()
        self.merging = BlockMerging()
        self.chroma = ChromaUpsampling()
        self.colors = YCbCr2RGBJpeg()

    def forward(self, y, cb, cr, imgh, imgw, factor=1):
        """
        Args:
            compressed(dict(tensor)): batch x h*w/64 x 8 x 8
            imgh(int)
            imgw(int)
            factor(float)

        Returns:
            Tensor: batch x 3 x height x width
        """
        components = {'y': y, 'cb': cb, 'cr': cr}
        for k in components.keys():
            if k in ('cb', 'cr'):
                comp = self.c_dequantize(components[k], factor=factor)
                height, width = int(imgh / 2), int(imgw / 2)
            else:
                comp = self.y_dequantize(components[k], factor=factor)
                height, width = imgh, imgw
            comp = self.idct(comp)
            components[k] = self.merging(comp, height, width)
        #
        image = self.chroma(components['y'], components['cb'], components['cr'])
        image = self.colors(image)

        image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
        return image / 255


# ------------------------ main DiffJPEG ------------------------ #


class DiffJPEG(nn.Module):
    """This JPEG algorithm result is slightly different from cv2.
    DiffJPEG supports batch processing.

    Args:
        differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard
            torch.round
    """

    def __init__(self, differentiable=True):
        super(DiffJPEG, self).__init__()
        if differentiable:
            rounding = diff_round
        else:
            rounding = torch.round

        self.compress = CompressJpeg(rounding=rounding)
        self.decompress = DeCompressJpeg(rounding=rounding)

    def forward(self, x, quality):
        """
        Args:
            x (Tensor): Input image, bchw, rgb, [0, 1]
            quality(float): Quality factor for jpeg compression scheme.
        """
        factor = quality
        if isinstance(factor, (int, float)):
            factor = quality_to_factor(factor)
        else:
            for i in range(factor.size(0)):
                factor[i] = quality_to_factor(factor[i])
        h, w = x.size()[-2:]
        h_pad, w_pad = 0, 0
        # why should use 16
        if h % 16 != 0:
            h_pad = 16 - h % 16
        if w % 16 != 0:
            w_pad = 16 - w % 16
        x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)

        y, cb, cr = self.compress(x, factor=factor)
        recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
        recovered = recovered[:, :, 0:h, 0:w]
        return recovered


if __name__ == '__main__':
    import cv2

    from basicsr.utils import img2tensor, tensor2img

    img_gt = cv2.imread('test.png') / 255.

    # -------------- cv2 -------------- #
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
    _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
    img_lq = np.float32(cv2.imdecode(encimg, 1))
    cv2.imwrite('cv2_JPEG_20.png', img_lq)

    # -------------- DiffJPEG -------------- #
    jpeger = DiffJPEG(differentiable=False).cuda()
    img_gt = img2tensor(img_gt)
    img_gt = torch.stack([img_gt, img_gt]).cuda()
    quality = img_gt.new_tensor([20, 40])
    out = jpeger(img_gt, quality=quality)

    cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
    cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
sr_model/Basicsr/basicsr/utils/dist_util.py
0 → 100644
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def get_dist_info():
    if dist.is_available():
        initialized = dist.is_initialized()
    else:
        initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
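
A small hedged sketch of how the master_only decorator above is typically applied; the decorated function is hypothetical:

@master_only
def log_to_disk(message):
    # Body executes only on rank 0; all other ranks silently return None.
    print(message)

log_to_disk('checkpoint saved')  # printed once even when launched on multiple GPUs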
sr_model/Basicsr/basicsr/utils/download_util.py
0 → 100644
import math
import os
import requests
from torch.hub import download_url_to_file, get_dir
from tqdm import tqdm
from urllib.parse import urlparse

from .misc import sizeof_fmt


def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """

    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # get file size
    response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None

    save_response_content(response, save_path, file_size)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response, destination, file_size=None, chunk_size=32768):
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')

        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            downloaded_size += chunk_size
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
        if pbar is not None:
            pbar.close()


def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
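
A hedged usage sketch for load_file_from_url; the URL and target directory below are placeholders rather than real checkpoints shipped with this repository:

from basicsr.utils.download_util import load_file_from_url

ckpt_path = load_file_from_url(
    url='https://example.com/pretrained/example_x4.pth',  # placeholder URL
    model_dir='experiments/pretrained_models',            # placeholder directory
    progress=True)
print(ckpt_path)  # absolute path of the cached file; later calls reuse it without downloading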