ModelZoo / ControlNet_pytorch / Commits

Commit e2696ece, authored Nov 22, 2023 by mashun1
controlnet
Pipeline #643: canceled with stages
Showing 20 changed files with 3452 additions and 0 deletions:
BasicSR/basicsr/archs/__init__.py            +24   -0
BasicSR/basicsr/archs/arch_util.py           +313  -0
BasicSR/basicsr/archs/basicvsr_arch.py       +336  -0
BasicSR/basicsr/archs/basicvsrpp_arch.py     +417  -0
BasicSR/basicsr/archs/dfdnet_arch.py         +169  -0
BasicSR/basicsr/archs/dfdnet_util.py         +162  -0
BasicSR/basicsr/archs/discriminator_arch.py  +150  -0
BasicSR/basicsr/archs/duf_arch.py            +276  -0
BasicSR/basicsr/archs/ecbsr_arch.py          +275  -0
BasicSR/basicsr/archs/edsr_arch.py           +61   -0
BasicSR/basicsr/archs/edvr_arch.py           +382  -0
BasicSR/basicsr/archs/hifacegan_arch.py      +260  -0
BasicSR/basicsr/archs/hifacegan_util.py      +255  -0
BasicSR/basicsr/archs/inception.py           +307  -0
BasicSR/basicsr/archs/rcan_arch.py           +0    -0
BasicSR/basicsr/archs/ridnet_arch.py         +0    -0
BasicSR/basicsr/archs/rrdbnet_arch.py        +0    -0
BasicSR/basicsr/archs/spynet_arch.py         +0    -0
BasicSR/basicsr/archs/srresnet_arch.py       +65   -0
BasicSR/basicsr/archs/srvgg_arch.py          +0    -0

BasicSR/basicsr/archs/__init__.py (new file, 0 → 100644)

import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY

__all__ = ['build_network']

# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]


def build_network(opt):
    opt = deepcopy(opt)
    network_type = opt.pop('type')
    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net
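
A minimal usage sketch for `build_network` (an editor's illustration, not part of the commit): the `type` key selects a class registered in `ARCH_REGISTRY`, and every remaining key is forwarded to that class's constructor. The option values below are assumptions for illustration.

```python
from basicsr.archs import build_network

# 'type' must name a class registered via @ARCH_REGISTRY.register();
# the other keys become keyword arguments of that class's __init__.
opt = {'type': 'BasicVSR', 'num_feat': 64, 'num_block': 15, 'spynet_path': None}
net = build_network(opt)  # logs: "Network [BasicVSR] is created."
```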

BasicSR/basicsr/archs/arch_util.py (new file, 0 → 100644)

import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from basicsr.utils import get_root_logger


@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)
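
A small sketch of how `make_layer` and `default_init_weights` combine (editor's illustration; the block choice is arbitrary): `make_layer` instantiates the block class `num_basic_block` times with the same keyword arguments, and `default_init_weights` then damps the Kaiming-initialized weights.

```python
import torch.nn as nn

from basicsr.archs.arch_util import default_init_weights, make_layer

# Three 64->64 3x3 convolutions stacked into an nn.Sequential.
blocks = make_layer(nn.Conv2d, 3, in_channels=64, out_channels=64, kernel_size=3, padding=1)
# Kaiming-init all conv weights, then scale them by 0.1 (as the residual blocks below do).
default_init_weights(blocks, scale=0.1)
```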
class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
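
A shape sketch for `Upsample` (editor's illustration): each conv-plus-`PixelShuffle(2)` stage expands the channels to `4 * num_feat` and then trades them for a 2x larger spatial grid, so `scale=4` chains two such stages.

```python
import torch

from basicsr.archs.arch_util import Upsample

up = Upsample(scale=4, num_feat=64)
x = torch.rand(1, 64, 16, 16)
print(up(x).shape)  # torch.Size([1, 64, 64, 64])
```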
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)

    # TODO, what if align_corners=False
    return output
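
A sanity sketch for `flow_warp` (editor's illustration): with a zero flow, every output pixel samples its own location, so the warp is an identity up to interpolation error.

```python
import torch

from basicsr.archs.arch_util import flow_warp

x = torch.rand(1, 3, 16, 16)
flow = torch.zeros(1, 16, 16, 2)  # (n, h, w, 2), zero displacement
out = flow_warp(x, flow)
print(torch.allclose(out, x, atol=1e-5))  # True
```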
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
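
`pixel_unshuffle` is the exact inverse of `nn.PixelShuffle`; since both are pure permutations of elements, a round trip recovers the input bit-for-bit (editor's sketch):

```python
import torch
import torch.nn as nn

from basicsr.archs.arch_util import pixel_unshuffle

x = torch.rand(1, 3, 8, 8)
y = pixel_unshuffle(x, scale=2)    # (1, 12, 4, 4): channels x4, spatial /2
z = nn.PixelShuffle(2)(y)          # back to (1, 3, 8, 8)
print(y.shape, torch.equal(z, x))  # torch.Size([1, 12, 4, 4]) True
```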
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes another different
    features to generate offsets and masks.

    ``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
    """

    def forward(self, x, feat):
        out = self.conv_offset(feat)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            logger = get_root_logger()
            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        else:
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        low = norm_cdf((a - mean) / std)
        up = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [low, up], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * low - 1, 2 * up - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


# From PyTorch
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
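
Two quick checks on the helpers above (editor's sketch): `trunc_normal_` clamps every sample into `[a, b]`, and the `_ntuple` helpers broadcast a scalar into a tuple while passing iterables through unchanged.

```python
import torch

from basicsr.archs.arch_util import to_2tuple, trunc_normal_

w = torch.empty(1000)
trunc_normal_(w, mean=0., std=1., a=-2., b=2.)
assert w.min() >= -2. and w.max() <= 2.  # all values clamped into [a, b]

print(to_2tuple(3))       # (3, 3)
print(to_2tuple((3, 5)))  # (3, 5) -- iterables are returned as-is
```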

BasicSR/basicsr/archs/basicvsr_arch.py (new file, 0 → 100644)

import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
from .edvr_arch import PCDAlignment, TSAFusion
from .spynet_arch import SpyNet


@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
    """A recurrent network for video SR. Now only x4 is supported.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
    """

    def __init__(self, num_feat=64, num_block=15, spynet_path=None):
        super().__init__()
        self.num_feat = num_feat

        # alignment
        self.spynet = SpyNet(spynet_path)

        # propagation
        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
        self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)

        # reconstruction
        self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # activation functions
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def get_flow(self, x):
        b, n, c, h, w = x.size()

        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def forward(self, x):
        """Forward function of BasicVSR.

        Args:
            x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
        """
        flows_forward, flows_backward = self.get_flow(x)
        b, n, _, h, w = x.size()

        # backward branch
        out_l = []
        feat_prop = x.new_zeros(b, self.num_feat, h, w)
        for i in range(n - 1, -1, -1):
            x_i = x[:, i, :, :, :]
            if i < n - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            feat_prop = self.backward_trunk(feat_prop)
            out_l.insert(0, feat_prop)

        # forward branch
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, n):
            x_i = x[:, i, :, :, :]
            if i > 0:
                flow = flows_forward[:, i - 1, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))

            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            feat_prop = self.forward_trunk(feat_prop)

            # upsample
            out = torch.cat([out_l[i], feat_prop], dim=1)
            out = self.lrelu(self.fusion(out))
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
            out += base
            out_l[i] = out

        return torch.stack(out_l, dim=1)
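
A shape check for `BasicVSR` (editor's sketch; `spynet_path=None` leaves SPyNet randomly initialized, so the pixels are meaningless but the shapes are real): each of the 5 input frames is upscaled x4.

```python
import torch

from basicsr.archs.basicvsr_arch import BasicVSR

model = BasicVSR(num_feat=64, num_block=15, spynet_path=None)
lqs = torch.rand(1, 5, 3, 64, 64)  # (b, n, c, h, w)
out = model(lqs)
print(out.shape)  # torch.Size([1, 5, 3, 256, 256])
```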
class ConvResidualBlocks(nn.Module):
    """Conv and residual block used in BasicVSR.

    Args:
        num_in_ch (int): Number of input channels. Default: 3.
        num_out_ch (int): Number of output channels. Default: 64.
        num_block (int): Number of residual blocks. Default: 15.
    """

    def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
            make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))

    def forward(self, fea):
        return self.main(fea)


@ARCH_REGISTRY.register()
class IconVSR(nn.Module):
    """IconVSR, proposed also in the BasicVSR paper.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15.
        keyframe_stride (int): Keyframe stride. Default: 5.
        temporal_padding (int): Temporal padding. Default: 2.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
        edvr_path (str): Path to the pretrained EDVR model. Default: None.
    """

    def __init__(self,
                 num_feat=64,
                 num_block=15,
                 keyframe_stride=5,
                 temporal_padding=2,
                 spynet_path=None,
                 edvr_path=None):
        super().__init__()

        self.num_feat = num_feat
        self.temporal_padding = temporal_padding
        self.keyframe_stride = keyframe_stride

        # keyframe_branch
        self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
        # alignment
        self.spynet = SpyNet(spynet_path)

        # propagation
        self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)

        self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)

        # reconstruction
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # activation functions
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def pad_spatial(self, x):
        """Apply padding spatially.

        Since the PCD module in EDVR requires that the resolution is a multiple
        of 4, we apply padding to the input LR images if their resolution is
        not divisible by 4.

        Args:
            x (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
        """
        n, t, c, h, w = x.size()

        pad_h = (4 - h % 4) % 4
        pad_w = (4 - w % 4) % 4

        # padding
        x = x.view(-1, c, h, w)
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')

        return x.view(n, t, c, h + pad_h, w + pad_w)

    def get_flow(self, x):
        b, n, c, h, w = x.size()

        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def get_keyframe_feature(self, x, keyframe_idx):
        if self.temporal_padding == 2:
            x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
        elif self.temporal_padding == 3:
            x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
        x = torch.cat(x, dim=1)

        num_frames = 2 * self.temporal_padding + 1
        feats_keyframe = {}
        for i in keyframe_idx:
            feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
        return feats_keyframe

    def forward(self, x):
        b, n, _, h_input, w_input = x.size()

        x = self.pad_spatial(x)
        h, w = x.shape[3:]

        keyframe_idx = list(range(0, n, self.keyframe_stride))
        if keyframe_idx[-1] != n - 1:
            keyframe_idx.append(n - 1)  # last frame is a keyframe

        # compute flow and keyframe features
        flows_forward, flows_backward = self.get_flow(x)
        feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)

        # backward branch
        out_l = []
        feat_prop = x.new_zeros(b, self.num_feat, h, w)
        for i in range(n - 1, -1, -1):
            x_i = x[:, i, :, :, :]
            if i < n - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in keyframe_idx:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.backward_fusion(feat_prop)
            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            feat_prop = self.backward_trunk(feat_prop)
            out_l.insert(0, feat_prop)

        # forward branch
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, n):
            x_i = x[:, i, :, :, :]
            if i > 0:
                flow = flows_forward[:, i - 1, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in keyframe_idx:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.forward_fusion(feat_prop)

            feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
            feat_prop = self.forward_trunk(feat_prop)

            # upsample
            out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
            out += base
            out_l[i] = out

        return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
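
How `IconVSR.forward` above picks keyframes (editor's sketch of the same logic): every `keyframe_stride`-th frame is a keyframe, and the last frame is forced to be one.

```python
n, keyframe_stride = 10, 5
keyframe_idx = list(range(0, n, keyframe_stride))  # [0, 5]
if keyframe_idx[-1] != n - 1:
    keyframe_idx.append(n - 1)                     # [0, 5, 9]
```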
class EDVRFeatureExtractor(nn.Module):
    """EDVR feature extractor used in IconVSR.

    Args:
        num_input_frame (int): Number of input frames.
        num_feat (int): Number of feature channels.
        load_path (str): Path to the pretrained weights of EDVR. Default: None.
    """

    def __init__(self, num_input_frame, num_feat, load_path):
        super(EDVRFeatureExtractor, self).__init__()

        self.center_frame_idx = num_input_frame // 2

        # extract pyramid features
        self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
        self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
        self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        if load_path:
            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])

    def forward(self, x):
        b, n, c, h, w = x.size()

        # extract features for each frame
        # L1
        feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
        feat_l1 = self.feature_extraction(feat_l1)
        # L2
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))

        feat_l1 = feat_l1.view(b, n, -1, h, w)
        feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)

        # PCD alignment
        ref_feat_l = [  # reference feature list
            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(n):
            nbr_feat_l = [  # neighboring feature list
                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)

        # TSA fusion
        return self.fusion(aligned_feat)

BasicSR/basicsr/archs/basicvsrpp_arch.py (new file, 0 → 100644)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import warnings

from basicsr.archs.arch_util import flow_warp
from basicsr.archs.basicvsr_arch import ConvResidualBlocks
from basicsr.archs.spynet_arch import SpyNet
from basicsr.ops.dcn import ModulatedDeformConvPack
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class BasicVSRPlusPlus(nn.Module):
    """BasicVSR++ network structure.

    Support either x4 upsampling or same size output. Since DCN is used in this
    model, it can only be used with CUDA enabled. If CUDA is not enabled,
    feature alignment will be skipped. Besides, we adopt the official DCN
    implementation and the version of torch need to be higher than 1.9.

    ``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``

    Args:
        mid_channels (int, optional): Channel number of the intermediate
            features. Default: 64.
        num_blocks (int, optional): The number of residual blocks in each
            propagation branch. Default: 7.
        max_residue_magnitude (int): The maximum magnitude of the offset
            residue (Eq. 6 in paper). Default: 10.
        is_low_res_input (bool, optional): Whether the input is low-resolution
            or not. If False, the output resolution is equal to the input
            resolution. Default: True.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
        cpu_cache_length (int, optional): When the length of sequence is larger
            than this value, the intermediate features are sent to CPU. This
            saves GPU memory, but slows down the inference speed. You can
            increase this number if you have a GPU with large memory.
            Default: 100.
    """

    def __init__(self,
                 mid_channels=64,
                 num_blocks=7,
                 max_residue_magnitude=10,
                 is_low_res_input=True,
                 spynet_path=None,
                 cpu_cache_length=100):

        super().__init__()
        self.mid_channels = mid_channels
        self.is_low_res_input = is_low_res_input
        self.cpu_cache_length = cpu_cache_length

        # optical flow
        self.spynet = SpyNet(spynet_path)

        # feature extraction module
        if is_low_res_input:
            self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
        else:
            self.feat_extract = nn.Sequential(
                nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
                nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
                ConvResidualBlocks(mid_channels, mid_channels, 5))

        # propagation branches
        self.deform_align = nn.ModuleDict()
        self.backbone = nn.ModuleDict()
        modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
        for i, module in enumerate(modules):
            if torch.cuda.is_available():
                self.deform_align[module] = SecondOrderDeformableAlignment(
                    2 * mid_channels,
                    mid_channels,
                    3,
                    padding=1,
                    deformable_groups=16,
                    max_residue_magnitude=max_residue_magnitude)
            self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)

        # upsampling module
        self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)

        self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)

        self.pixel_shuffle = nn.PixelShuffle(2)

        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # check if the sequence is augmented by flipping
        self.is_mirror_extended = False

        if len(self.deform_align) > 0:
            self.is_with_alignment = True
        else:
            self.is_with_alignment = False
            warnings.warn('Deformable alignment module is not added. '
                          'Probably your CUDA is not configured correctly. DCN can only '
                          'be used with CUDA enabled. Alignment is skipped now.')

    def check_if_mirror_extended(self, lqs):
        """Check whether the input is a mirror-extended sequence.

        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame.

        Args:
            lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
        """

        if lqs.size(1) % 2 == 0:
            lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
            if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
                self.is_mirror_extended = True
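
What `check_if_mirror_extended` detects (editor's sketch): a sequence extended by appending its own temporal flip, so that chunking it in half and flipping the second half reproduces the first half exactly.

```python
import torch

half = torch.rand(1, 2, 3, 8, 8)               # (n, t/2, c, h, w)
lqs = torch.cat([half, half.flip(1)], dim=1)   # mirror-extended, t = 4
lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
print(torch.norm(lqs_1 - lqs_2.flip(1)) == 0)  # tensor(True)
```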
    def compute_flow(self, lqs):
        """Compute optical flow using SPyNet for feature alignment.

        Note that if the input is an mirror-extended sequence, 'flows_forward'
        is not needed, since it is equal to 'flows_backward.flip(1)'.

        Args:
            lqs (tensor): Input low quality (LQ) sequence with
                shape (n, t, c, h, w).

        Return:
            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \
                (current to previous). 'flows_backward' corresponds to the flows used for backward-time \
                propagation (current to next).
        """

        n, t, c, h, w = lqs.size()
        lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
        lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)

        if self.is_mirror_extended:  # flows_forward = flows_backward.flip(1)
            flows_forward = flows_backward.flip(1)
        else:
            flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)

        if self.cpu_cache:
            flows_backward = flows_backward.cpu()
            flows_forward = flows_forward.cpu()

        return flows_forward, flows_backward

    def propagate(self, feats, flows, module_name):
        """Propagate the latent features throughout the sequence.

        Args:
            feats dict(list[tensor]): Features from previous branches. Each
                component is a list of tensors with shape (n, c, h, w).
            flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
            module_name (str): The name of the propgation branches. Can either
                be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.

        Return:
            dict(list[tensor]): A dictionary containing all the propagated \
                features. Each key in the dictionary corresponds to a \
                propagation branch, which is represented by a list of tensors.
        """

        n, t, _, h, w = flows.size()

        frame_idx = range(0, t + 1)
        flow_idx = range(-1, t)
        mapping_idx = list(range(0, len(feats['spatial'])))
        mapping_idx += mapping_idx[::-1]

        if 'backward' in module_name:
            frame_idx = frame_idx[::-1]
            flow_idx = frame_idx

        feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
        for i, idx in enumerate(frame_idx):
            feat_current = feats['spatial'][mapping_idx[idx]]
            if self.cpu_cache:
                feat_current = feat_current.cuda()
                feat_prop = feat_prop.cuda()
            # second-order deformable alignment
            if i > 0 and self.is_with_alignment:
                flow_n1 = flows[:, flow_idx[i], :, :, :]
                if self.cpu_cache:
                    flow_n1 = flow_n1.cuda()

                cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))

                # initialize second-order features
                feat_n2 = torch.zeros_like(feat_prop)
                flow_n2 = torch.zeros_like(flow_n1)
                cond_n2 = torch.zeros_like(cond_n1)

                if i > 1:  # second-order features
                    feat_n2 = feats[module_name][-2]
                    if self.cpu_cache:
                        feat_n2 = feat_n2.cuda()

                    flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
                    if self.cpu_cache:
                        flow_n2 = flow_n2.cuda()

                    flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
                    cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))

                # flow-guided deformable convolution
                cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
                feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
                feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)

            # concatenate and residual blocks
            feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
            if self.cpu_cache:
                feat = [f.cuda() for f in feat]

            feat = torch.cat(feat, dim=1)
            feat_prop = feat_prop + self.backbone[module_name](feat)
            feats[module_name].append(feat_prop)

            if self.cpu_cache:
                feats[module_name][-1] = feats[module_name][-1].cpu()
                torch.cuda.empty_cache()

        if 'backward' in module_name:
            feats[module_name] = feats[module_name][::-1]

        return feats

    def upsample(self, lqs, feats):
        """Compute the output image given the features.

        Args:
            lqs (tensor): Input low quality (LQ) sequence with
                shape (n, t, c, h, w).
            feats (dict): The features from the propagation branches.

        Returns:
            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
        """

        outputs = []
        num_outputs = len(feats['spatial'])

        mapping_idx = list(range(0, num_outputs))
        mapping_idx += mapping_idx[::-1]

        for i in range(0, lqs.size(1)):
            hr = [feats[k].pop(0) for k in feats if k != 'spatial']
            hr.insert(0, feats['spatial'][mapping_idx[i]])
            hr = torch.cat(hr, dim=1)
            if self.cpu_cache:
                hr = hr.cuda()

            hr = self.reconstruction(hr)
            hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
            hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
            hr = self.lrelu(self.conv_hr(hr))
            hr = self.conv_last(hr)
            if self.is_low_res_input:
                hr += self.img_upsample(lqs[:, i, :, :, :])
            else:
                hr += lqs[:, i, :, :, :]

            if self.cpu_cache:
                hr = hr.cpu()
                torch.cuda.empty_cache()

            outputs.append(hr)

        return torch.stack(outputs, dim=1)

    def forward(self, lqs):
        """Forward function for BasicVSR++.

        Args:
            lqs (tensor): Input low quality (LQ) sequence with
                shape (n, t, c, h, w).

        Returns:
            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
        """

        n, t, c, h, w = lqs.size()

        # whether to cache the features in CPU
        self.cpu_cache = True if t > self.cpu_cache_length else False

        if self.is_low_res_input:
            lqs_downsample = lqs.clone()
        else:
            lqs_downsample = F.interpolate(
                lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)

        # check whether the input is an extended sequence
        self.check_if_mirror_extended(lqs)

        feats = {}
        # compute spatial features
        if self.cpu_cache:
            feats['spatial'] = []
            for i in range(0, t):
                feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
                feats['spatial'].append(feat)
                torch.cuda.empty_cache()
        else:
            feats_ = self.feat_extract(lqs.view(-1, c, h, w))
            h, w = feats_.shape[2:]
            feats_ = feats_.view(n, t, -1, h, w)
            feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]

        # compute optical flow using the low-res inputs
        assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
            'The height and width of low-res inputs must be at least 64, '
            f'but got {h} and {w}.')
        flows_forward, flows_backward = self.compute_flow(lqs_downsample)

        # feature propgation
        for iter_ in [1, 2]:
            for direction in ['backward', 'forward']:
                module = f'{direction}_{iter_}'

                feats[module] = []

                if direction == 'backward':
                    flows = flows_backward
                elif flows_forward is not None:
                    flows = flows_forward
                else:
                    flows = flows_backward.flip(1)

                feats = self.propagate(feats, flows, module)
                if self.cpu_cache:
                    del flows
                    torch.cuda.empty_cache()

        return self.upsample(lqs, feats)


class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
    """Second-order deformable alignment module.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
        max_residue_magnitude (int): The maximum magnitude of the offset
            residue (Eq. 6 in paper). Default: 10.
    """

    def __init__(self, *args, **kwargs):
        self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)

        super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Sequential(
            nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
        )

        self.init_offset()

    def init_offset(self):

        def _constant_init(module, val, bias=0):
            if hasattr(module, 'weight') and module.weight is not None:
                nn.init.constant_(module.weight, val)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.constant_(module.bias, bias)

        _constant_init(self.conv_offset[-1], val=0, bias=0)

    def forward(self, x, extra_feat, flow_1, flow_2):
        extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
        out = self.conv_offset(extra_feat)
        o1, o2, mask = torch.chunk(out, 3, dim=1)

        # offset
        offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
        offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
        offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
        offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
        offset = torch.cat([offset_1, offset_2], dim=1)

        # mask
        mask = torch.sigmoid(mask)

        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                             self.dilation, mask)


# if __name__ == '__main__':
#     spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
#     model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
#     input = torch.rand(1, 2, 3, 64, 64).cuda()
#     output = model(input)
#     print('===================')
#     print(output.shape)
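
Channel bookkeeping behind the `27 * self.deformable_groups` output of `conv_offset` above (editor's note): `torch.chunk(out, 3, dim=1)` splits it into two offset thirds (one guided by each flow) and one mask third; after concatenation the offsets form the `2 * k * k` channels per group that `torchvision.ops.deform_conv2d` expects for a `k x k` kernel, and the masks the remaining `k * k`.

```python
kernel_size, deformable_groups = 3, 16
offset_channels = 2 * kernel_size**2 * deformable_groups  # 288: an (x, y) pair per sampling point
mask_channels = kernel_size**2 * deformable_groups        # 144: one modulation gate per sampling point
assert offset_channels + mask_channels == 27 * deformable_groups  # 432
```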

BasicSR/basicsr/archs/dfdnet_arch.py (new file, 0 → 100644)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import spectral_norm

from basicsr.utils.registry import ARCH_REGISTRY
from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
from .vgg_arch import VGGFeatureExtractor


class SFTUpBlock(nn.Module):
    """Spatial feature transform (SFT) with upsampling block.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Kernel size in convolutions. Default: 3.
        padding (int): Padding in convolutions. Default: 1.
    """

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
        super(SFTUpBlock, self).__init__()
        self.conv1 = nn.Sequential(
            Blur(in_channel),
            spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.04, True),
            # The official codes use two LeakyReLU here, so 0.04 for equivalent
        )
        self.convup = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.2, True),
        )

        # for SFT scale and shift
        self.scale_block = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
        self.shift_block = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
        # The official codes use sigmoid for shift block, do not know why

    def forward(self, x, updated_feat):
        out = self.conv1(x)
        # SFT
        scale = self.scale_block(updated_feat)
        shift = self.shift_block(updated_feat)
        out = out * scale + shift
        # upsample
        out = self.convup(out)
        return out


@ARCH_REGISTRY.register()
class DFDNet(nn.Module):
    """DFDNet: Deep Face Dictionary Network.

    It only processes faces with 512x512 size.

    Args:
        num_feat (int): Number of feature channels.
        dict_path (str): Path to the facial component dictionary.
    """

    def __init__(self, num_feat, dict_path):
        super().__init__()
        self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
        # part_sizes: [80, 80, 50, 110]
        channel_sizes = [128, 256, 512, 512]
        self.feature_sizes = np.array([256, 128, 64, 32])
        self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
        self.flag_dict_device = False

        # dict
        self.dict = torch.load(dict_path)

        # vgg face extractor
        self.vgg_extractor = VGGFeatureExtractor(
            layer_name_list=self.vgg_layers,
            vgg_type='vgg19',
            use_input_norm=True,
            range_norm=True,
            requires_grad=False)

        # attention block for fusing dictionary features and input features
        self.attn_blocks = nn.ModuleDict()
        for idx, feat_size in enumerate(self.feature_sizes):
            for name in self.parts:
                self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])

        # multi scale dilation block
        self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])

        # upsampling and reconstruction
        self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
        self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
        self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
        self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
        self.upsample4 = nn.Sequential(
            spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
            UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())

    def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
        """swap the features from the dictionary."""
        # get the original vgg features
        part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
        # resize original vgg features
        part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
        # use adaptive instance normalization to adjust color and illuminations
        dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
        # get similarity scores
        similarity_score = F.conv2d(part_resize_feat, dict_feat)
        similarity_score = F.softmax(similarity_score.view(-1), dim=0)
        # select the most similar features in the dict (after norm)
        select_idx = torch.argmax(similarity_score)
        swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])

        # attention
        attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
        attn_feat = attn * swap_feat
        # update features
        updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat

        return updated_feat

    def put_dict_to_device(self, x):
        if self.flag_dict_device is False:
            for k, v in self.dict.items():
                for kk, vv in v.items():
                    self.dict[k][kk] = vv.to(x)

            self.flag_dict_device = True

    def forward(self, x, part_locations):
        """
        Now only support testing with batch size = 0.

        Args:
            x (Tensor): Input faces with shape (b, c, 512, 512).
            part_locations (list[Tensor]): Part locations.
        """
        self.put_dict_to_device(x)
        # extract vggface features
        vgg_features = self.vgg_extractor(x)
        # update vggface features using the dictionary for each part
        updated_vgg_features = []
        batch = 0  # only supports testing with batch size = 0
        for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
            dict_features = self.dict[f'{f_size}']
            vgg_feat = vgg_features[vgg_layer]
            updated_feat = vgg_feat.clone()
            # swap features from dictionary
            for part_idx, part_name in enumerate(self.parts):
                location = (part_locations[part_idx][batch] // (512 / f_size)).int()
                updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
                                              f_size)

            updated_vgg_features.append(updated_feat)

        vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
        # use updated vgg features to modulate the upsampled features with
        # SFT (Spatial Feature Transform) scaling and shifting manner.
        upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
        upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
        upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
        upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
        out = self.upsample4(upsampled_feat)

        return out
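
How `swap_feat` above scores dictionary entries (editor's sketch with hypothetical sizes): the K candidate features act as K convolution kernels, so a single `F.conv2d` over an equally-sized input patch yields one correlation score per candidate.

```python
import torch
import torch.nn.functional as F

part = torch.rand(1, 8, 5, 5)        # resized part feature from the input face
dict_feat = torch.rand(10, 8, 5, 5)  # 10 dictionary candidates for this part
scores = F.softmax(F.conv2d(part, dict_feat).view(-1), dim=0)  # shape (10,)
best = torch.argmax(scores)          # index of the most similar candidate
```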

BasicSR/basicsr/archs/dfdnet_util.py (new file, 0 → 100644)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.utils.spectral_norm import spectral_norm


class BlurFunctionBackward(Function):

    @staticmethod
    def forward(ctx, grad_output, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_output):
        kernel, _ = ctx.saved_tensors
        grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
        return grad_input, None, None


class BlurFunction(Function):

    @staticmethod
    def forward(ctx, x, kernel, kernel_flip):
        ctx.save_for_backward(kernel, kernel_flip)
        output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
        return output

    @staticmethod
    def backward(ctx, grad_output):
        kernel, kernel_flip = ctx.saved_tensors
        grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
        return grad_input, None, None


blur = BlurFunction.apply


class Blur(nn.Module):

    def __init__(self, channel):
        super().__init__()
        kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
        kernel = kernel.view(1, 1, 3, 3)
        kernel = kernel / kernel.sum()
        kernel_flip = torch.flip(kernel, [2, 3])

        self.kernel = kernel.repeat(channel, 1, 1, 1)
        self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)

    def forward(self, x):
        return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))


def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    n, c = size[:2]
    feat_var = feat.view(n, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(n, c, 1, 1)
    feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have the similar color and illuminations
    as those in the degradate features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degradate features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
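
A numerical check of `adaptive_instance_normalization` (editor's sketch): the output keeps the content's spatial structure but takes on the style's per-channel mean and standard deviation, up to the `eps` used in `calc_mean_std`.

```python
import torch

from basicsr.archs.dfdnet_util import adaptive_instance_normalization, calc_mean_std

content = torch.rand(2, 8, 16, 16)
style = torch.rand(2, 8, 16, 16) * 3.0 + 1.0
out = adaptive_instance_normalization(content, style)

out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style)
print(torch.allclose(out_mean, style_mean, atol=1e-3))  # True
print(torch.allclose(out_std, style_std, atol=1e-3))    # True
```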
def AttentionBlock(in_channel):
    return nn.Sequential(
        spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
        spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))


def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
    """Conv block used in MSDilationBlock."""

    return nn.Sequential(
        spectral_norm(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=((kernel_size - 1) // 2) * dilation,
                bias=bias)),
        nn.LeakyReLU(0.2),
        spectral_norm(
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=((kernel_size - 1) // 2) * dilation,
                bias=bias)),
    )


class MSDilationBlock(nn.Module):
    """Multi-scale dilation block."""

    def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
        super(MSDilationBlock, self).__init__()

        self.conv_blocks = nn.ModuleList()
        for i in range(4):
            self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
        self.conv_fusion = spectral_norm(
            nn.Conv2d(
                in_channels * 4,
                in_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                bias=bias))

    def forward(self, x):
        out = []
        for i in range(4):
            out.append(self.conv_blocks[i](x))
        out = torch.cat(out, 1)
        out = self.conv_fusion(out) + x
        return out


class UpResBlock(nn.Module):

    def __init__(self, in_channel):
        super(UpResBlock, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
        )

    def forward(self, x):
        out = x + self.body(x)
        return out

BasicSR/basicsr/archs/discriminator_arch.py (new file, 0 → 100644)

from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm

from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class VGGStyleDiscriminator(nn.Module):
    """VGG style discriminator with input size 128 x 128 or 256 x 256.

    It is used to train SRGAN, ESRGAN, and VideoGAN.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
    """

    def __init__(self, num_in_ch, num_feat, input_size=128):
        super(VGGStyleDiscriminator, self).__init__()
        self.input_size = input_size
        assert self.input_size == 128 or self.input_size == 256, (
            f'input size must be 128 or 256, but received {input_size}')

        self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)

        self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
        self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
        self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
        self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)

        self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
        self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
        self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
        self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)

        self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        if self.input_size == 256:
            self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
            self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
            self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
            self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')

        feat = self.lrelu(self.conv0_0(x))
        feat = self.lrelu(self.bn0_1(self.conv0_1(feat)))  # output spatial size: /2

        feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
        feat = self.lrelu(self.bn1_1(self.conv1_1(feat)))  # output spatial size: /4

        feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
        feat = self.lrelu(self.bn2_1(self.conv2_1(feat)))  # output spatial size: /8

        feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
        feat = self.lrelu(self.bn3_1(self.conv3_1(feat)))  # output spatial size: /16

        feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
        feat = self.lrelu(self.bn4_1(self.conv4_1(feat)))  # output spatial size: /32

        if self.input_size == 256:
            feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
            feat = self.lrelu(self.bn5_1(self.conv5_1(feat)))  # output spatial size: /64

        # spatial size: (4, 4)
        feat = feat.view(feat.size(0), -1)
        feat = self.lrelu(self.linear1(feat))
        out = self.linear2(feat)
        return out
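
A shape check for `VGGStyleDiscriminator` (editor's sketch): five stride-2 stages reduce a 128x128 input to 4x4, which is exactly the `num_feat * 8 * 4 * 4` flattened size the first linear layer expects.

```python
import torch

from basicsr.archs.discriminator_arch import VGGStyleDiscriminator

d = VGGStyleDiscriminator(num_in_ch=3, num_feat=64, input_size=128)
x = torch.rand(1, 3, 128, 128)
print(d(x).shape)  # torch.Size([1, 1]) -- one realness score per image
```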
@ARCH_REGISTRY.register(suffix='basicsr')
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN).

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        # the first convolution
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsample
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra convolutions
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        # downsample
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra convolutions
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
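
A quick shape sanity-check for the U-Net discriminator above (an editor's sketch, not part of the original file). The three stride-2 convolutions shrink the input by 8x and the three bilinear upsamples restore it, so the patch score map comes back at the input resolution; height and width should therefore be divisible by 8.

import torch

disc = UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
x = torch.rand(1, 3, 128, 128)  # (b, c, h, w); h and w divisible by 8
score_map = disc(x)
print(score_map.shape)  # torch.Size([1, 1, 128, 128]) -- one realness score per pixel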
BasicSR/basicsr/archs/duf_arch.py
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY
class DenseBlocksTemporalReduce(nn.Module):
    """A concatenation of 3 dense blocks with reduction in temporal dimension.

    Note that the output temporal dimension is 6 fewer than the input temporal dimension, since there are 3 blocks.

    Args:
        num_feat (int): Number of channels in the blocks. Default: 64.
        num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
            Set to false if you want to train from scratch. Default: False.
    """

    def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
        super(DenseBlocksTemporalReduce, self).__init__()
        if adapt_official_weights:
            eps = 1e-3
            momentum = 1e-3
        else:  # pytorch default values
            eps = 1e-05
            momentum = 0.1

        self.temporal_reduce1 = nn.Sequential(
            nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
            nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))

        self.temporal_reduce2 = nn.Sequential(
            nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(
                num_feat + num_grow_ch,
                num_feat + num_grow_ch, (1, 1, 1),
                stride=(1, 1, 1),
                padding=(0, 0, 0),
                bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))

        self.temporal_reduce3 = nn.Sequential(
            nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
            nn.Conv3d(
                num_feat + 2 * num_grow_ch,
                num_feat + 2 * num_grow_ch, (1, 1, 1),
                stride=(1, 1, 1),
                padding=(0, 0, 0),
                bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).

        Returns:
            Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
        """
        x1 = self.temporal_reduce1(x)
        x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)

        x2 = self.temporal_reduce2(x1)
        x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)

        x3 = self.temporal_reduce3(x2)
        x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)

        return x3
class DenseBlocks(nn.Module):
    """A concatenation of N dense blocks.

    Args:
        num_feat (int): Number of channels in the blocks. Default: 64.
        num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
        num_block (int): Number of dense blocks. The values are:
            DUF-S (16 layers): 3
            DUF-M (28 layers): 9
            DUF-L (52 layers): 21
        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
            Set to false if you want to train from scratch. Default: False.
    """

    def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
        super(DenseBlocks, self).__init__()
        if adapt_official_weights:
            eps = 1e-3
            momentum = 1e-3
        else:  # pytorch default values
            eps = 1e-05
            momentum = 0.1

        self.dense_blocks = nn.ModuleList()
        for i in range(0, num_block):
            self.dense_blocks.append(
                nn.Sequential(
                    nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
                    nn.Conv3d(
                        num_feat + i * num_grow_ch,
                        num_feat + i * num_grow_ch, (1, 1, 1),
                        stride=(1, 1, 1),
                        padding=(0, 0, 0),
                        bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
                    nn.ReLU(inplace=True),
                    nn.Conv3d(
                        num_feat + i * num_grow_ch,
                        num_grow_ch, (3, 3, 3),
                        stride=(1, 1, 1),
                        padding=(1, 1, 1),
                        bias=True)))

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).

        Returns:
            Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
        """
        for i in range(0, len(self.dense_blocks)):
            y = self.dense_blocks[i](x)
            x = torch.cat((x, y), 1)
        return x
class DynamicUpsamplingFilter(nn.Module):
    """Dynamic upsampling filter used in DUF.

    Reference: https://github.com/yhjo09/VSR-DUF

    It only supports input with 3 channels, and it applies the same filters to all 3 channels.

    Args:
        filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
    """

    def __init__(self, filter_size=(5, 5)):
        super(DynamicUpsamplingFilter, self).__init__()
        if not isinstance(filter_size, tuple):
            raise TypeError(f'The type of filter_size must be tuple, but got type {filter_size}')
        if len(filter_size) != 2:
            raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
        # generate a local expansion filter, similar to im2col
        self.filter_size = filter_size
        filter_prod = np.prod(filter_size)
        expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size)  # (kh*kw, 1, kh, kw)
        self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1)  # repeat for all the 3 channels

    def forward(self, x, filters):
        """Forward function for DynamicUpsamplingFilter.

        Args:
            x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
            filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w).
                filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
                upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling.
                    e.g., for x4 upsampling, upsampling_square = 4*4 = 16.

        Returns:
            Tensor: Filtered image with shape (n, 3*upsampling_square, h, w).
        """
        n, filter_prod, upsampling_square, h, w = filters.size()
        kh, kw = self.filter_size
        expanded_input = F.conv2d(
            x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3)  # (n, 3*filter_prod, h, w)
        expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1, 2)  # (n, h, w, 3, filter_prod)
        filters = filters.permute(0, 3, 4, 1, 2)  # (n, h, w, filter_prod, upsampling_square)
        out = torch.matmul(expanded_input, filters)  # (n, h, w, 3, upsampling_square)
        return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
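
# Editor's illustration (not in the original file): a shape walk-through of the
# dynamic filtering step for x4 upsampling, where filter_prod = 5*5 = 25 and
# upsampling_square = 4*4 = 16. Inert helper, safe to delete.
def _demo_dynamic_filter():
    duf = DynamicUpsamplingFilter((5, 5))
    img = torch.rand(2, 3, 16, 16)  # (n, 3, h, w)
    # per-pixel filters, normalized over the 25 taps as DUF.forward does
    filt = torch.softmax(torch.rand(2, 25, 16, 16, 16), dim=1)
    out = duf(img, filt)  # (2, 3 * 16, 16, 16)
    hr = F.pixel_shuffle(out, 4)  # (2, 3, 64, 64)
    return hr.shape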
@ARCH_REGISTRY.register()
class DUF(nn.Module):
    """Network architecture for DUF.

    ``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation``

    Reference: https://github.com/yhjo09/VSR-DUF

    For all the models below, 'adapt_official_weights' is only necessary when loading the weights converted from the
    official TensorFlow weights. Please set it to False if you are training the model from scratch.

    There are three models with different model size: DUF16Layers, DUF28Layers, and DUF52Layers.
    This class is the base class for these models.

    Args:
        scale (int): The upsampling factor. Default: 4.
        num_layer (int): The number of layers. Default: 52.
        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
            Set to false if you want to train from scratch. Default: False.
    """

    def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
        super(DUF, self).__init__()
        self.scale = scale
        if adapt_official_weights:
            eps = 1e-3
            momentum = 1e-3
        else:  # pytorch default values
            eps = 1e-05
            momentum = 0.1

        self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
        self.dynamic_filter = DynamicUpsamplingFilter((5, 5))

        if num_layer == 16:
            num_block = 3
            num_grow_ch = 32
        elif num_layer == 28:
            num_block = 9
            num_grow_ch = 16
        elif num_layer == 52:
            num_block = 21
            num_grow_ch = 16
        else:
            raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')

        self.dense_block1 = DenseBlocks(
            num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
            adapt_official_weights=adapt_official_weights)  # T = 7
        self.dense_block2 = DenseBlocksTemporalReduce(
            64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights)  # T = 1
        channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
        self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
        self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)

        self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
        self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)

        self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
        self.conv3d_f2 = nn.Conv3d(
            512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input with shape (b, 7, c, h, w).

        Returns:
            Tensor: Output with shape (b, c, h * scale, w * scale).
        """
        num_batches, num_imgs, _, h, w = x.size()

        x = x.permute(0, 2, 1, 3, 4)  # (b, c, 7, h, w) for Conv3D
        x_center = x[:, :, num_imgs // 2, :, :]

        x = self.conv3d1(x)
        x = self.dense_block1(x)
        x = self.dense_block2(x)
        x = F.relu(self.bn3d2(x), inplace=True)
        x = F.relu(self.conv3d2(x), inplace=True)

        # residual image
        res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))

        # filter
        filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
        filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)

        # dynamic filter
        out = self.dynamic_filter(x_center, filter_)
        out += res.squeeze_(2)
        out = F.pixel_shuffle(out, self.scale)

        return out
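
A minimal usage sketch for DUF (an editor's addition, not in the original file). It assumes a 7-frame input clip as stated in the forward docstring; eval mode avoids BatchNorm3d batch-statistics issues at batch size 1.

import torch

model = DUF(scale=4, num_layer=16).eval()
frames = torch.rand(1, 7, 3, 32, 32)  # (b, t=7, c, h, w)
with torch.no_grad():
    sr = model(frames)
print(sr.shape)  # torch.Size([1, 3, 128, 128])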
BasicSR/basicsr/archs/ecbsr_arch.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.utils.registry import ARCH_REGISTRY
class SeqConv3x3(nn.Module):
    """The re-parameterizable block used in the ECBSR architecture.

    ``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices``

    Reference: https://github.com/xindongzhang/ECBSR

    Args:
        seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
        in_channels (int): Channel number of input.
        out_channels (int): Channel number of output.
        depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
    """

    def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
        super(SeqConv3x3, self).__init__()
        self.seq_type = seq_type
        self.in_channels = in_channels
        self.out_channels = out_channels

        if self.seq_type == 'conv1x1-conv3x3':
            self.mid_planes = int(out_channels * depth_multiplier)
            conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
            self.k1 = conv1.weight
            self.b1 = conv1.bias

        elif self.seq_type == 'conv1x1-sobelx':
            conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale and bias
            scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_channels) * 1e-3
            bias = torch.reshape(bias, (self.out_channels, ))
            self.bias = nn.Parameter(bias)
            # init mask
            self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_channels):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 1, 0] = 2.0
                self.mask[i, 0, 2, 0] = 1.0
                self.mask[i, 0, 0, 2] = -1.0
                self.mask[i, 0, 1, 2] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.seq_type == 'conv1x1-sobely':
            conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale and bias
            scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(torch.FloatTensor(scale))
            bias = torch.randn(self.out_channels) * 1e-3
            bias = torch.reshape(bias, (self.out_channels, ))
            self.bias = nn.Parameter(torch.FloatTensor(bias))
            # init mask
            self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_channels):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 0, 1] = 2.0
                self.mask[i, 0, 0, 2] = 1.0
                self.mask[i, 0, 2, 0] = -1.0
                self.mask[i, 0, 2, 1] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.seq_type == 'conv1x1-laplacian':
            conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale and bias
            scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(torch.FloatTensor(scale))
            bias = torch.randn(self.out_channels) * 1e-3
            bias = torch.reshape(bias, (self.out_channels, ))
            self.bias = nn.Parameter(torch.FloatTensor(bias))
            # init mask
            self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_channels):
                self.mask[i, 0, 0, 1] = 1.0
                self.mask[i, 0, 1, 0] = 1.0
                self.mask[i, 0, 1, 2] = 1.0
                self.mask[i, 0, 2, 1] = 1.0
                self.mask[i, 0, 1, 1] = -4.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)
        else:
            raise ValueError('The type of seqconv is not supported!')

    def forward(self, x):
        if self.seq_type == 'conv1x1-conv3x3':
            # conv-1x1
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # conv-3x3
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        else:
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # conv-3x3
            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
        return y1

    def rep_params(self):
        device = self.k0.get_device()
        if device < 0:
            device = None

        if self.seq_type == 'conv1x1-conv3x3':
            # re-param conv kernel
            rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
        else:
            tmp = self.scale * self.mask
            k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
            for i in range(self.out_channels):
                k1[i, i, :, :] = tmp[i, 0, :, :]
            b1 = self.bias
            # re-param conv kernel
            rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
        return rep_weight, rep_bias
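
# Editor's check (not in the original file): the folded kernel/bias returned by
# rep_params() should reproduce the two-stage forward up to float tolerance.
# Inert helper, safe to delete.
def _check_seqconv_rep():
    seq = SeqConv3x3('conv1x1-sobelx', in_channels=8, out_channels=8)
    x = torch.rand(1, 8, 12, 12)
    w, b = seq.rep_params()
    y_seq = seq(x)  # 1x1 conv + bias padding + depthwise Sobel conv
    y_rep = F.conv2d(x, w, b, stride=1, padding=1)  # single plain 3x3 conv
    return torch.allclose(y_seq, y_rep, atol=1e-5)  # expected: True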
class ECB(nn.Module):
    """The ECB block used in the ECBSR architecture.

    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
    Ref git repo: https://github.com/xindongzhang/ECBSR

    Args:
        in_channels (int): Channel number of input.
        out_channels (int): Channel number of output.
        depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
        act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
        with_idt (bool): Whether to use identity connection. Default: False.
    """

    def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
        super(ECB, self).__init__()

        self.depth_multiplier = depth_multiplier
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.act_type = act_type

        if with_idt and (self.in_channels == self.out_channels):
            self.with_idt = True
        else:
            self.with_idt = False

        self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
        self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
        self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
        self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
        self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_channels)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            pass
        else:
            raise ValueError('The type of activation is not supported!')

    def forward(self, x):
        if self.training:
            y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
            if self.with_idt:
                y += x
        else:
            rep_weight, rep_bias = self.rep_params()
            y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
        if self.act_type != 'linear':
            y = self.act(y)
        return y

    def rep_params(self):
        weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
        weight1, bias1 = self.conv1x1_3x3.rep_params()
        weight2, bias2 = self.conv1x1_sbx.rep_params()
        weight3, bias3 = self.conv1x1_sby.rep_params()
        weight4, bias4 = self.conv1x1_lpl.rep_params()
        rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
            bias0 + bias1 + bias2 + bias3 + bias4)

        if self.with_idt:
            device = rep_weight.get_device()
            if device < 0:
                device = None
            weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
            for i in range(self.out_channels):
                weight_idt[i, i, 1, 1] = 1.0
            bias_idt = 0.0
            rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
        return rep_weight, rep_bias
@ARCH_REGISTRY.register()
class ECBSR(nn.Module):
    """ECBSR architecture.

    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
    Ref git repo: https://github.com/xindongzhang/ECBSR

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_block (int): Block number in the trunk network.
        num_channel (int): Channel number.
        with_idt (bool): Whether to use identity in convolution layers.
        act_type (str): Activation type.
        scale (int): Upsampling factor.
    """

    def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
        super(ECBSR, self).__init__()
        self.num_in_ch = num_in_ch
        self.scale = scale

        backbone = []
        backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
        for _ in range(num_block):
            backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
        backbone += [
            ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
        ]
        self.backbone = nn.Sequential(*backbone)
        self.upsampler = nn.PixelShuffle(scale)

    def forward(self, x):
        if self.num_in_ch > 1:
            shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
        else:
            shortcut = x  # will repeat the input in the channel dimension (repeat scale * scale times)
        y = self.backbone(x) + shortcut
        y = self.upsampler(y)
        return y
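
A minimal end-to-end sketch (editor's addition, not in the original file). In train mode each ECB runs its five parallel branches; switching to eval() makes every block use the single folded 3x3 conv from rep_params(), so the two modes should agree numerically.

import torch

net = ECBSR(num_in_ch=1, num_out_ch=1, num_block=4, num_channel=8,
            with_idt=False, act_type='prelu', scale=2).eval()
lr = torch.rand(1, 1, 24, 24)
with torch.no_grad():
    sr = net(lr)
print(sr.shape)  # torch.Size([1, 1, 48, 48])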
BasicSR/basicsr/archs/edsr_arch.py
import torch
from torch import nn as nn

from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
from basicsr.utils.registry import ARCH_REGISTRY
@ARCH_REGISTRY.register()
class EDSR(nn.Module):
    """EDSR network structure.

    Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
    Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Block number in the trunk network. Default: 16.
        upscale (int): Upsampling factor. Support 2^n and 3. Default: 4.
        res_scale (float): Used to scale the residual in residual block. Default: 1.
        img_range (float): Image range. Default: 255.
        rgb_mean (tuple[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 num_feat=64,
                 num_block=16,
                 upscale=4,
                 res_scale=1,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)):
        super(EDSR, self).__init__()

        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        self.mean = self.mean.type_as(x)

        x = (x - self.mean) * self.img_range
        x = self.conv_first(x)
        res = self.conv_after_body(self.body(x))
        res += x

        x = self.conv_last(self.upsample(res))
        x = x / self.img_range + self.mean

        return x
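
A minimal usage sketch for EDSR (editor's addition, not part of the original file). The forward handles the mean subtraction and image-range scaling internally, so inputs in [0, 1] work directly.

import torch

model = EDSR(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4)
lr = torch.rand(1, 3, 24, 24)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # torch.Size([1, 3, 96, 96])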
BasicSR/basicsr/archs/edvr_arch.py
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
class PCDAlignment(nn.Module):
    """Alignment module using Pyramid, Cascading and Deformable convolution (PCD). It is used in EDVR.

    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        deformable_groups (int): Deformable groups. Defaults: 8.
    """

    def __init__(self, num_feat=64, deformable_groups=8):
        super(PCDAlignment, self).__init__()

        # Pyramid has three levels:
        # L3: level 3, 1/4 spatial size
        # L2: level 2, 1/2 spatial size
        # L1: level 1, original spatial size
        self.offset_conv1 = nn.ModuleDict()
        self.offset_conv2 = nn.ModuleDict()
        self.offset_conv3 = nn.ModuleDict()
        self.dcn_pack = nn.ModuleDict()
        self.feat_conv = nn.ModuleDict()

        # Pyramids
        for i in range(3, 0, -1):
            level = f'l{i}'
            self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
            if i == 3:
                self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            else:
                self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
                self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

            if i < 3:
                self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)

        # Cascading dcn
        self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_feat_l, ref_feat_l):
        """Align neighboring frame features to the reference frame features.

        Args:
            nbr_feat_l (list[Tensor]): Neighboring feature list. It contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).
            ref_feat_l (list[Tensor]): Reference feature list. It contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).

        Returns:
            Tensor: Aligned features.
        """
        # Pyramids
        upsampled_offset, upsampled_feat = None, None
        for i in range(3, 0, -1):
            level = f'l{i}'
            offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
            offset = self.lrelu(self.offset_conv1[level](offset))
            if i == 3:
                offset = self.lrelu(self.offset_conv2[level](offset))
            else:
                offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
                offset = self.lrelu(self.offset_conv3[level](offset))

            feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
            if i < 3:
                feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
            if i > 1:
                feat = self.lrelu(feat)

            if i > 1:
                # upsample offset and features
                # x2: when we upsample the offset, we should also enlarge the magnitude.
                upsampled_offset = self.upsample(offset) * 2
                upsampled_feat = self.upsample(feat)

        # Cascading
        offset = torch.cat([feat, ref_feat_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
        feat = self.lrelu(self.cas_dcnpack(feat, offset))
        return feat
class TSAFusion(nn.Module):
    """Temporal Spatial Attention (TSA) fusion module.

    Temporal: Calculate the correlation between center frame and neighboring frames;
    Spatial: It has 3 pyramid levels, the attention is similar to SFT.
        (SFT: Recovering realistic texture in image super-resolution by deep spatial feature transform.)

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        num_frame (int): Number of frames. Default: 5.
        center_frame_idx (int): The index of center frame. Default: 2.
    """

    def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
        super(TSAFusion, self).__init__()
        self.center_frame_idx = center_frame_idx
        # temporal attention (before fusion conv)
        self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
        # spatial attention (after fusion conv)
        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
        self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
        self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
        self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, aligned_feat):
        """
        Args:
            aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).

        Returns:
            Tensor: Features after TSA with the shape (b, c, h, w).
        """
        b, t, c, h, w = aligned_feat.size()
        # temporal attention
        embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
        embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
        embedding = embedding.view(b, t, -1, h, w)  # (b, t, c, h, w)

        corr_l = []  # correlation list
        for i in range(t):
            emb_neighbor = embedding[:, i, :, :, :]
            corr = torch.sum(emb_neighbor * embedding_ref, 1)  # (b, h, w)
            corr_l.append(corr.unsqueeze(1))  # (b, 1, h, w)
        corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1))  # (b, t, h, w)
        corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
        corr_prob = corr_prob.contiguous().view(b, -1, h, w)  # (b, t*c, h, w)
        aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob

        # fusion
        feat = self.lrelu(self.feat_fusion(aligned_feat))

        # spatial attention
        attn = self.lrelu(self.spatial_attn1(aligned_feat))
        attn_max = self.max_pool(attn)
        attn_avg = self.avg_pool(attn)
        attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
        # pyramid levels
        attn_level = self.lrelu(self.spatial_attn_l1(attn))
        attn_max = self.max_pool(attn_level)
        attn_avg = self.avg_pool(attn_level)
        attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
        attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
        attn_level = self.upsample(attn_level)

        attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
        attn = self.lrelu(self.spatial_attn4(attn))
        attn = self.upsample(attn)
        attn = self.spatial_attn5(attn)
        attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
        attn = torch.sigmoid(attn)

        # after initialization, * 2 makes (attn * 2) to be close to 1.
        feat = feat * attn * 2 + attn_add
        return feat
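
# Editor's restatement (not in the original file) of the temporal-attention
# step above: the per-frame correlation is just a channel-wise dot product
# against the center-frame embedding, written here without the Python loop.
# Inert helper, safe to delete.
def _temporal_corr(embedding, center_idx):
    # embedding: (b, t, c, h, w), as produced by temporal_attn2 + view
    ref = embedding[:, center_idx]  # (b, c, h, w)
    corr = (embedding * ref.unsqueeze(1)).sum(2)  # (b, t, h, w)
    return torch.sigmoid(corr)  # soft per-pixel, per-frame weights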
class PredeblurModule(nn.Module):
    """Pre-deblur module.

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        hr_in (bool): Whether the input has high resolution. Default: False.
    """

    def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
        super(PredeblurModule, self).__init__()
        self.hr_in = hr_in

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        if self.hr_in:
            # downsample x4 by stride conv
            self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
            self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        # generate feature pyramid
        self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        feat_l1 = self.lrelu(self.conv_first(x))
        if self.hr_in:
            feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
            feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))

        # generate feature pyramid
        feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
        feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))

        feat_l3 = self.upsample(self.resblock_l3(feat_l3))
        feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
        feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))

        for i in range(2):
            feat_l1 = self.resblock_l1[i](feat_l1)
        feat_l1 = feat_l1 + feat_l2
        for i in range(2, 5):
            feat_l1 = self.resblock_l1[i](feat_l1)
        return feat_l1
@ARCH_REGISTRY.register()
class EDVR(nn.Module):
    """EDVR network structure for video super-resolution.

    Now only supports x4 upsampling factor.

    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_out_ch (int): Channel number of output image. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_frame (int): Number of input frames. Default: 5.
        deformable_groups (int): Deformable groups. Defaults: 8.
        num_extract_block (int): Number of blocks for feature extraction. Default: 5.
        num_reconstruct_block (int): Number of blocks for reconstruction. Default: 10.
        center_frame_idx (int): The index of center frame. Frame counting from 0. Default: Middle of input frames.
        hr_in (bool): Whether the input has high resolution. Default: False.
        with_predeblur (bool): Whether to use the predeblur module. Default: False.
        with_tsa (bool): Whether to use the TSA module. Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 num_feat=64,
                 num_frame=5,
                 deformable_groups=8,
                 num_extract_block=5,
                 num_reconstruct_block=10,
                 center_frame_idx=None,
                 hr_in=False,
                 with_predeblur=False,
                 with_tsa=True):
        super(EDVR, self).__init__()
        if center_frame_idx is None:
            self.center_frame_idx = num_frame // 2
        else:
            self.center_frame_idx = center_frame_idx
        self.hr_in = hr_in
        self.with_predeblur = with_predeblur
        self.with_tsa = with_tsa

        # extract features for each frame
        if self.with_predeblur:
            self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
            self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
        else:
            self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)

        # extract pyramid features
        self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
        if self.with_tsa:
            self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)

        # reconstruction
        self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
        # upsample
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        b, t, c, h, w = x.size()
        if self.hr_in:
            assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
        else:
            assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')

        x_center = x[:, self.center_frame_idx, :, :, :].contiguous()

        # extract features for each frame
        # L1
        if self.with_predeblur:
            feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
            if self.hr_in:
                h, w = h // 4, w // 4
        else:
            feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))

        feat_l1 = self.feature_extraction(feat_l1)
        # L2
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))

        feat_l1 = feat_l1.view(b, t, -1, h, w)
        feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)

        # PCD alignment
        ref_feat_l = [  # reference feature list
            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(t):
            nbr_feat_l = [  # neighboring feature list
                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)

        if not self.with_tsa:
            aligned_feat = aligned_feat.view(b, -1, h, w)
        feat = self.fusion(aligned_feat)

        out = self.reconstruction(feat)
        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        out = self.lrelu(self.conv_hr(out))
        out = self.conv_last(out)
        if self.hr_in:
            base = x_center
        else:
            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out += base
        return out
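
A minimal usage sketch (editor's addition, not in the original file). PCD alignment relies on DCNv2Pack, so the deformable-convolution ops shipped with BasicSR must be available; height and width must be multiples of 4 when hr_in is False.

import torch

model = EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5, with_tsa=True)
clip = torch.rand(1, 5, 3, 64, 64)  # (b, t, c, h, w)
with torch.no_grad():
    sr = model(clip)
print(sr.shape)  # torch.Size([1, 3, 256, 256]) -- fixed x4 upsampling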
BasicSR/basicsr/archs/hifacegan_arch.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
class SPADEGenerator(BaseNetwork):
    """Generator with SPADEResBlock"""

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):  # progressive training disabled
        super().__init__()
        self.nf = num_feat
        self.input_nc = num_in_ch
        self.is_train = is_train
        self.train_phase = init_train_phase

        self.scale_ratio = 5  # hardcoded now
        self.sw = crop_size // (2**self.scale_ratio)
        self.sh = self.sw  # 20210519: By default use square image, aspect_ratio = 1.0

        if use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
        self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
            SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
            SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
            SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
        ])

        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * self.nf, 3, 3, padding=1),
            nn.Conv2d(4 * self.nf, 3, 3, padding=1),
            nn.Conv2d(2 * self.nf, 3, 3, padding=1),
            nn.Conv2d(1 * self.nf, 3, 3, padding=1)
        ])

        self.up = nn.Upsample(scale_factor=2)

    def encode(self, input_tensor):
        """
        Encode input_tensor into feature maps; can be overridden in derived classes.
        Default: nearest downsampling of 2**5 = 32 times.
        """
        h, w = input_tensor.size()[-2:]
        sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
        x = F.interpolate(input_tensor, size=(sh, sw))
        return self.fc(x)

    def forward(self, x):
        # In original SPADE, seg means a segmentation map, but here we use x instead.
        seg = x

        x = self.encode(x)
        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.g_middle_0(x, seg)
        x = self.g_middle_1(x, seg)

        if self.is_train:
            phase = self.train_phase + 1
        else:
            phase = len(self.to_rgbs)

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, seg)

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x

    def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
        """
        A helper function for subspace visualization. Input and seg are different images.
        For the first n levels (including encoder) we use input, for the rest we use seg.

        If mode = 'progressive', the output's like: AAABBB
        If mode = 'one_plug', the output's like: AAABAA
        If mode = 'one_ablate', the output's like: BBBABB
        """
        if seg is None:
            return self.forward(input_x)

        if self.is_train:
            phase = self.train_phase + 1
        else:
            phase = len(self.to_rgbs)

        if mode == 'progressive':
            n = max(min(n, 4 + phase), 0)
            guide_list = [input_x] * n + [seg] * (4 + phase - n)
        elif mode == 'one_plug':
            n = max(min(n, 4 + phase - 1), 0)
            guide_list = [seg] * (4 + phase)
            guide_list[n] = input_x
        elif mode == 'one_ablate':
            if n > 3 + phase:
                return self.forward(input_x)
            guide_list = [input_x] * (4 + phase)
            guide_list[n] = seg

        x = self.encode(guide_list[0])
        x = self.head_0(x, guide_list[1])

        x = self.up(x)
        x = self.g_middle_0(x, guide_list[2])
        x = self.g_middle_1(x, guide_list[3])

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, guide_list[4 + i])

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x
@ARCH_REGISTRY.register()
class HiFaceGAN(SPADEGenerator):
    """
    HiFaceGAN: SPADEGenerator with a learnable feature encoder.
    Current encoder design: LIPEncoder
    """

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):
        super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
        self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)

    def encode(self, input_tensor):
        return self.lip_encoder(input_tensor)
@ARCH_REGISTRY.register()
class HiFaceGANDiscriminator(BaseNetwork):
    """
    Inspired by pix2pixHD multiscale discriminator.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        conditional_d (bool): Whether to use a conditional discriminator. Default: True.
        num_d (int): Number of multiscale discriminators. Default: 2.
        n_layers_d (int): Number of downsample layers in each D. Default: 4.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        norm_d (str): String to determine normalization layers in D.
            Choices: [spectral][instance/batch/syncbatch]. Default: 'spectralinstance'.
        keep_features (bool): Keep intermediate features for matching loss, etc. Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 conditional_d=True,
                 num_d=2,
                 n_layers_d=4,
                 num_feat=64,
                 norm_d='spectralinstance',
                 keep_features=True):
        super().__init__()
        self.num_d = num_d

        input_nc = num_in_ch
        if conditional_d:
            input_nc += num_out_ch

        for i in range(num_d):
            subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
            self.add_module(f'discriminator_{i}', subnet_d)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)

    # Returns list of lists of discriminator outputs.
    # The final result is of size opt.num_d x opt.n_layers_D
    def forward(self, x):
        result = []
        for _, _net_d in self.named_children():
            out = _net_d(x)
            result.append(out)
            x = self.downsample(x)
        return result
class NLayerDiscriminator(BaseNetwork):
    """Defines the PatchGAN discriminator with the specified arguments."""

    def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
        super().__init__()
        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        nf = num_feat
        self.keep_features = keep_features

        norm_layer = get_nonspade_norm_layer(norm_d)
        sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]

        for n in range(1, n_layers_d):
            nf_prev = nf
            nf = min(nf * 2, 512)
            stride = 1 if n == n_layers_d - 1 else 2
            sequence += [[
                norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
                nn.LeakyReLU(0.2, False)
            ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        # We divide the layers into groups to extract intermediate layer outputs
        for n in range(len(sequence)):
            self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

    def forward(self, x):
        results = [x]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)

        if self.keep_features:
            return results[1:]
        else:
            return results[-1]
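
A sketch of the multiscale output structure (editor's addition, not in the original file). With conditional_d=True the discriminator expects the conditioning image and the generated/real image concatenated along channels; each scale returns its per-group intermediate features when keep_features is True.

import torch

disc = HiFaceGANDiscriminator(num_in_ch=3, num_out_ch=3, conditional_d=True,
                              num_d=2, n_layers_d=4, num_feat=64)
pair = torch.rand(1, 6, 128, 128)  # condition (3 ch) + image (3 ch)
outs = disc(pair)
print(len(outs))     # 2 scales
print(len(outs[0]))  # 5 grouped layer outputs per scale (n_layers_d + 1)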
BasicSR/basicsr/archs/hifacegan_util.py
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

# Warning: spectral norm could be buggy
# under eval mode and multi-GPU inference
# A workaround is sticking to single-GPU inference and train mode
from torch.nn.utils import spectral_norm
class SPADE(nn.Module):

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()
        assert config_text.startswith('spade')
        parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'syncbatch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128 if norm_nc > 128 else norm_nc

        pw = ks // 2
        self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)

    def forward(self, x, segmap):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * gamma + beta
        return out
class SPADEResnetBlock(nn.Module):
    """
    ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
    it takes in the segmentation map as input, learns the skip connection if necessary,
    and applies normalization first and then convolution.
    This architecture seemed like a standard architecture for unconditional or
    class-conditional GAN architecture using residual block.
    The code was inspired from https://github.com/LMescheder/GAN_stability.
    """

    def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)
        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
        # apply spectral norm if specified
        if 'spectral' in norm_g:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)
        # define normalization layers
        spade_config_str = norm_g.replace('spectral', '')
        self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
        self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, semantic_nc)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)
        dx = self.conv_0(self.act(self.norm_0(x, seg)))
        dx = self.conv_1(self.act(self.norm_1(dx, seg)))
        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def act(self, x):
        return F.leaky_relu(x, 2e-1)
class BaseNetwork(nn.Module):
    """A basis for hifacegan archs with custom initialization"""

    def init_weights(self, init_type='normal', gain=0.02):

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)

    def forward(self, x):
        pass
def lip2d(x, logit, kernel=3, stride=2, padding=1):
    weight = logit.exp()
    return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
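
# Editor's note (not in the original file): with a zero logit map the weights
# are exp(0) = 1 everywhere, so lip2d degenerates to plain average pooling;
# larger logits bias the pooled value toward "important" pixels. Inert helper,
# safe to delete.
def _lip2d_sanity():
    x = torch.arange(16.0).view(1, 1, 4, 4)
    y = lip2d(x, torch.zeros_like(x))  # kernel=3, stride=2, padding=1
    return y.shape  # torch.Size([1, 1, 2, 2])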
class SoftGate(nn.Module):
    COEFF = 12.0

    def forward(self, x):
        return torch.sigmoid(x).mul(self.COEFF)
class SimplifiedLIP(nn.Module):

    def __init__(self, channels):
        super(SimplifiedLIP, self).__init__()
        self.logit = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
            SoftGate())

    def init_layer(self):
        self.logit[0].weight.data.fill_(0.0)

    def forward(self, x):
        frac = lip2d(x, self.logit(x))
        return frac
class LIPEncoder(BaseNetwork):
    """Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)"""

    def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16

        # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
        kw = 3
        pw = (kw - 1) // 2

        model = [
            nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model += [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
def get_nonspade_norm_layer(norm_type='instance'):

    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'sync_batch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            # norm_layer = SynchronizedBatchNorm2d(
            #    get_out_channel(layer), affine=True)
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError(f'normalization layer {subnorm_type} is not recognized')

        return nn.Sequential(layer, norm_layer)

    print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
    return add_norm_layer
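
A usage sketch for the norm-layer factory above (editor's addition, not in the original file): 'spectralinstance' wraps the conv in spectral norm, drops its now-redundant bias, and appends an InstanceNorm2d.

add_norm = get_nonspade_norm_layer('spectralinstance')
layer = add_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))
print(layer)  # Sequential(Conv2d (spectral-normed, bias removed), InstanceNorm2d)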
BasicSR/basicsr/archs/inception.py
# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
# For FID metric

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.model_zoo import load_url
from torchvision import models

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501
LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501
class
InceptionV3
(
nn
.
Module
):
"""Pretrained InceptionV3 network returning feature maps"""
# Index of default block of inception to return,
# corresponds to output of final average pooling
DEFAULT_BLOCK_INDEX
=
3
# Maps feature dimensionality to their output blocks indices
BLOCK_INDEX_BY_DIM
=
{
64
:
0
,
# First max pooling features
192
:
1
,
# Second max pooling features
768
:
2
,
# Pre-aux classifier features
2048
:
3
# Final average pooling features
}
def
__init__
(
self
,
output_blocks
=
(
DEFAULT_BLOCK_INDEX
),
resize_input
=
True
,
normalize_input
=
True
,
requires_grad
=
False
,
use_fid_inception
=
True
):
"""Build pretrained InceptionV3.
Args:
output_blocks (list[int]): Indices of blocks to return features of.
Possible values are:
- 0: corresponds to output of first max pooling
- 1: corresponds to output of second max pooling
- 2: corresponds to output which is fed to aux classifier
- 3: corresponds to output of final average pooling
resize_input (bool): If true, bilinearly resizes input to width and
height 299 before feeding input to model. As the network
without fully connected layers is fully convolutional, it
should be able to handle inputs of arbitrary size, so resizing
might not be strictly needed. Default: True.
normalize_input (bool): If true, scales the input from range (0, 1)
to the range the pretrained Inception network expects,
namely (-1, 1). Default: True.
requires_grad (bool): If true, parameters of the model require
gradients. Possibly useful for finetuning the network.
Default: False.
use_fid_inception (bool): If true, uses the pretrained Inception
model used in Tensorflow's FID implementation.
If false, uses the pretrained Inception model available in
torchvision. The FID Inception model has different weights
and a slightly different structure from torchvision's
Inception model. If you want to compute FID scores, you are
strongly advised to set this parameter to true to get
comparable results. Default: True.
"""
super
(
InceptionV3
,
self
).
__init__
()
self
.
resize_input
=
resize_input
self
.
normalize_input
=
normalize_input
self
.
output_blocks
=
sorted
(
output_blocks
)
self
.
last_needed_block
=
max
(
output_blocks
)
assert
self
.
last_needed_block
<=
3
,
(
'Last possible output block index is 3'
)
self
.
blocks
=
nn
.
ModuleList
()
if
use_fid_inception
:
inception
=
fid_inception_v3
()
else
:
try
:
inception
=
models
.
inception_v3
(
pretrained
=
True
,
init_weights
=
False
)
except
TypeError
:
# pytorch < 1.5 does not have init_weights for inception_v3
inception
=
models
.
inception_v3
(
pretrained
=
True
)
# Block 0: input to maxpool1
block0
=
[
inception
.
Conv2d_1a_3x3
,
inception
.
Conv2d_2a_3x3
,
inception
.
Conv2d_2b_3x3
,
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
)
]
self
.
blocks
.
append
(
nn
.
Sequential
(
*
block0
))
# Block 1: maxpool1 to maxpool2
if
self
.
last_needed_block
>=
1
:
block1
=
[
inception
.
Conv2d_3b_1x1
,
inception
.
Conv2d_4a_3x3
,
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
)]
self
.
blocks
.
append
(
nn
.
Sequential
(
*
block1
))
# Block 2: maxpool2 to aux classifier
if
self
.
last_needed_block
>=
2
:
block2
=
[
inception
.
Mixed_5b
,
inception
.
Mixed_5c
,
inception
.
Mixed_5d
,
inception
.
Mixed_6a
,
inception
.
Mixed_6b
,
inception
.
Mixed_6c
,
inception
.
Mixed_6d
,
inception
.
Mixed_6e
,
]
self
.
blocks
.
append
(
nn
.
Sequential
(
*
block2
))
# Block 3: aux classifier to final avgpool
if
self
.
last_needed_block
>=
3
:
block3
=
[
inception
.
Mixed_7a
,
inception
.
Mixed_7b
,
inception
.
Mixed_7c
,
nn
.
AdaptiveAvgPool2d
(
output_size
=
(
1
,
1
))
]
self
.
blocks
.
append
(
nn
.
Sequential
(
*
block3
))
for
param
in
self
.
parameters
():
param
.
requires_grad
=
requires_grad
def
forward
(
self
,
x
):
"""Get Inception feature maps.
Args:
x (Tensor): Input tensor of shape (b, 3, h, w).
Values are expected to be in range (-1, 1). You can also input
(0, 1) with setting normalize_input = True.
Returns:
list[Tensor]: Corresponding to the selected output block, sorted
ascending by index.
"""
output
=
[]
if
self
.
resize_input
:
x
=
F
.
interpolate
(
x
,
size
=
(
299
,
299
),
mode
=
'bilinear'
,
align_corners
=
False
)
if
self
.
normalize_input
:
x
=
2
*
x
-
1
# Scale from range (0, 1) to range (-1, 1)
for
idx
,
block
in
enumerate
(
self
.
blocks
):
x
=
block
(
x
)
if
idx
in
self
.
output_blocks
:
output
.
append
(
x
)
if
idx
==
self
.
last_needed_block
:
break
return
output
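A hedged usage example for FID-style feature extraction (illustrative; downloads the ported FID weights on first use, and input sizes are arbitrary thanks to the resize):

import torch

model = InceptionV3(output_blocks=[3], use_fid_inception=True).eval()
imgs = torch.rand(4, 3, 64, 64)  # values in (0, 1); normalize_input rescales to (-1, 1)
with torch.no_grad():
    feats = model(imgs)[0]  # one tensor per requested block; only block 3 was asked for
print(feats.shape)  # torch.Size([4, 2048, 1, 1]) -- final average pooling features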
def fid_inception_v3():
    """Build pretrained Inception model for FID computation.

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.
    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    try:
        inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
    except TypeError:
        # pytorch < 1.5 does not have init_weights for inception_v3
        inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    if os.path.exists(LOCAL_FID_WEIGHTS):
        state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
    else:
        state_dict = load_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception
class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched for FID computation"""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not include the padded zeros
        # in its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
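The `count_include_pad=False` patch only changes values where the 3x3 window overlaps the zero padding; a small standalone check (nothing assumed beyond `torch`) makes the difference concrete:

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 4, 4)
a = F.avg_pool2d(x, 3, stride=1, padding=1)  # PyTorch default counts padded zeros: corner mean = 4/9
b = F.avg_pool2d(x, 3, stride=1, padding=1, count_include_pad=False)  # divides by the real window size: 4/4
print(a[0, 0, 0, 0].item(), b[0, 0, 0, 0].item())  # ~0.444 vs 1.0 at the corner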
class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched for FID computation"""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not include the padded zeros
        # in its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not include the padded zeros
        # in its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
BasicSR/basicsr/archs/rcan_arch.py (new file, diff collapsed)
BasicSR/basicsr/archs/ridnet_arch.py (new file, diff collapsed)
BasicSR/basicsr/archs/rrdbnet_arch.py (new file, diff collapsed)
BasicSR/basicsr/archs/spynet_arch.py (new file, diff collapsed)
BasicSR/basicsr/archs/srresnet_arch.py (new file)
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer


@ARCH_REGISTRY.register()
class MSRResNet(nn.Module):
    """Modified SRResNet.

    A compact version modified from SRResNet in
    "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network".
    It uses residual blocks without BN, similar to EDSR.
    Currently, it supports x2, x3 and x4 upsampling scale factors.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Block number in the body network. Default: 16.
        upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
        super(MSRResNet, self).__init__()
        self.upscale = upscale

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)

        # upsampling
        if self.upscale in [2, 3]:
            self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
            self.pixel_shuffle = nn.PixelShuffle(self.upscale)
        elif self.upscale == 4:
            self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
            self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialization
        default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
        if self.upscale == 4:
            default_init_weights(self.upconv2, 0.1)

    def forward(self, x):
        feat = self.lrelu(self.conv_first(x))
        out = self.body(feat)

        if self.upscale == 4:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        elif self.upscale in [2, 3]:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))

        out = self.conv_last(self.lrelu(self.conv_hr(out)))
        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        out += base
        return out
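A quick shape check for `MSRResNet` (illustrative values; assumes the BasicSR package is installed for `make_layer` and `ResidualBlockNoBN`):

import torch

net = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4)
lr = torch.rand(1, 3, 32, 32)
sr = net(lr)
print(sr.shape)  # torch.Size([1, 3, 128, 128]): x4 via two PixelShuffle(2) stages plus a bilinear skip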
BasicSR/basicsr/archs/srvgg_arch.py (new file, diff collapsed)