ModelZoo / ControlNet_pytorch

Commit e2696ece, authored Nov 22, 2023 by mashun1

controlnet

Pipeline #643 canceled with stages · Changes: 263 · Pipelines: 1
Showing 20 changed files with 2996 additions and 0 deletions.
BasicSR/basicsr/metrics/psnr_ssim.py  (+231, -0)
BasicSR/basicsr/metrics/test_metrics/test_psnr_ssim.py  (+52, -0)
BasicSR/basicsr/models/__init__.py  (+29, -0)
BasicSR/basicsr/models/base_model.py  (+392, -0)
BasicSR/basicsr/models/edvr_model.py  (+62, -0)
BasicSR/basicsr/models/esrgan_model.py  (+83, -0)
BasicSR/basicsr/models/hifacegan_model.py  (+288, -0)
BasicSR/basicsr/models/lr_scheduler.py  (+96, -0)
BasicSR/basicsr/models/realesrgan_model.py  (+267, -0)
BasicSR/basicsr/models/realesrnet_model.py  (+189, -0)
BasicSR/basicsr/models/sr_model.py  (+279, -0)
BasicSR/basicsr/models/srgan_model.py  (+149, -0)
BasicSR/basicsr/models/stylegan2_model.py  (+283, -0)
BasicSR/basicsr/models/swinir_model.py  (+33, -0)
BasicSR/basicsr/models/video_base_model.py  (+160, -0)
BasicSR/basicsr/models/video_gan_model.py  (+19, -0)
BasicSR/basicsr/models/video_recurrent_gan_model.py  (+180, -0)
BasicSR/basicsr/models/video_recurrent_model.py  (+197, -0)
BasicSR/basicsr/ops/__init__.py  (+0, -0)
BasicSR/basicsr/ops/dcn/__init__.py  (+7, -0)
BasicSR/basicsr/metrics/psnr_ssim.py (new file, mode 100644)
import cv2
import numpy as np
import torch
import torch.nn.functional as F

from basicsr.metrics.metric_util import reorder_image, to_y_channel
from basicsr.utils.color_util import rgb2ycbcr_pt
from basicsr.utils.registry import METRIC_REGISTRY


@METRIC_REGISTRY.register()
def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: PSNR result.
    """

    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    img = reorder_image(img, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img = to_y_channel(img)
        img2 = to_y_channel(img2)

    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)

    mse = np.mean((img - img2)**2)
    if mse == 0:
        return float('inf')
    return 10. * np.log10(255. * 255. / mse)


@METRIC_REGISTRY.register()
def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
    """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).

    Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: PSNR result.
    """

    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

    if crop_border != 0:
        img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
        img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]

    if test_y_channel:
        img = rgb2ycbcr_pt(img, y_only=True)
        img2 = rgb2ycbcr_pt(img2, y_only=True)

    img = img.to(torch.float64)
    img2 = img2.to(torch.float64)

    mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
    return 10. * torch.log10(1. / (mse + 1e-8))


@METRIC_REGISTRY.register()
def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
    """Calculate SSIM (structural similarity).

    ``Paper: Image quality assessment: From error visibility to structural similarity``

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: SSIM result.
    """

    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    img = reorder_image(img, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img = to_y_channel(img)
        img2 = to_y_channel(img2)

    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)

    ssims = []
    for i in range(img.shape[2]):
        ssims.append(_ssim(img[..., i], img2[..., i]))
    return np.array(ssims).mean()


@METRIC_REGISTRY.register()
def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
    """Calculate SSIM (structural similarity) (PyTorch version).

    ``Paper: Image quality assessment: From error visibility to structural similarity``

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: SSIM result.
    """

    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

    if crop_border != 0:
        img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
        img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]

    if test_y_channel:
        img = rgb2ycbcr_pt(img, y_only=True)
        img2 = rgb2ycbcr_pt(img2, y_only=True)

    img = img.to(torch.float64)
    img2 = img2.to(torch.float64)

    ssim = _ssim_pth(img * 255., img2 * 255.)
    return ssim


def _ssim(img, img2):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: SSIM result.
    """

    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]  # valid mode for window size 11
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
    return ssim_map.mean()


def _ssim_pth(img, img2):
    """Calculate SSIM (structural similarity) (PyTorch version).

    It is called by func:`calculate_ssim_pt`.

    Args:
        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).

    Returns:
        float: SSIM result.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())
    window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device)

    mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1])  # valid mode
    mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1])  # valid mode
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq
    sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2

    cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
    ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map
    return ssim_map.mean([1, 2, 3])
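As a quick sanity check of the PSNR formula used above, 10 * log10(255^2 / MSE), here is a minimal, self-contained sketch on synthetic arrays; the array shapes and pixel values are made up for illustration:

import numpy as np

# Two toy 8-bit "images" that differ by a constant offset of 5.
img = np.full((32, 32), 100, dtype=np.float64)
img2 = np.full((32, 32), 105, dtype=np.float64)

mse = np.mean((img - img2) ** 2)          # 25.0
psnr = 10. * np.log10(255. * 255. / mse)  # 10 * log10(65025 / 25) ~= 34.15 dB
print(f'MSE: {mse}, PSNR: {psnr:.2f} dB')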
BasicSR/basicsr/metrics/test_metrics/test_psnr_ssim.py (new file, mode 100644)
import cv2
import torch

from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
from basicsr.utils import img2tensor


def test(img_path, img_path2, crop_border, test_y_channel=False):
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED)

    # --------------------- Numpy ---------------------
    psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
    ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
    print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')

    # --------------------- PyTorch (CPU) ---------------------
    img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
    img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)

    psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
    ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
    print(f'\tTensor (CPU)\tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')

    # --------------------- PyTorch (GPU) ---------------------
    img = img.cuda()
    img2 = img2.cuda()
    psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
    ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
    print(f'\tTensor (GPU)\tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')

    psnr_pth = calculate_psnr_pt(
        torch.repeat_interleave(img, 2, dim=0),
        torch.repeat_interleave(img2, 2, dim=0),
        crop_border=crop_border,
        test_y_channel=test_y_channel)
    ssim_pth = calculate_ssim_pt(
        torch.repeat_interleave(img, 2, dim=0),
        torch.repeat_interleave(img2, 2, dim=0),
        crop_border=crop_border,
        test_y_channel=test_y_channel)
    print(f'\tTensor (GPU batch)\tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
          f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}')


if __name__ == '__main__':
    test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False)
    test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True)
    test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False)
    test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True)
BasicSR/basicsr/models/__init__.py (new file, mode 100644)
import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import MODEL_REGISTRY

__all__ = ['build_model']

# automatically scan and import model modules for registry
# scan all the files under the 'models' folder and collect files ending with '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]


def build_model(opt):
    """Build model from options.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Model type.
    """
    opt = deepcopy(opt)
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
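The registry-driven lookup that `build_model` performs can be illustrated without the full BasicSR config machinery. A minimal, self-contained sketch of the same pattern; the toy `REGISTRY`, `register` and `build` names here are illustrative stand-ins, not part of this commit:

from copy import deepcopy

# Toy registry mirroring the lookup pattern used by build_model.
REGISTRY = {}

def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls

@register
class SRModel:
    def __init__(self, opt):
        self.opt = opt

def build(opt):
    opt = deepcopy(opt)                      # avoid mutating the caller's dict
    return REGISTRY[opt['model_type']](opt)  # same lookup as MODEL_REGISTRY.get(...)

model = build({'model_type': 'SRModel'})
print(model.__class__.__name__)  # SRModel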
BasicSR/basicsr/models/base_model.py (new file, mode 100644)
import os
import time
import torch
from collections import OrderedDict
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel

from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils import get_root_logger
from basicsr.utils.dist_util import master_only


class BaseModel():
    """Base model."""

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.is_train = opt['is_train']
        self.schedulers = []
        self.optimizers = []

    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def save(self, epoch, current_iter):
        """Save networks and training state."""
        pass

    def validation(self, dataloader, current_iter, tb_logger, save_img=False):
        """Validation function.

        Args:
            dataloader (torch.utils.data.DataLoader): Validation dataloader.
            current_iter (int): Current iteration.
            tb_logger (tensorboard logger): Tensorboard logger.
            save_img (bool): Whether to save images. Default: False.
        """
        if self.opt['dist']:
            self.dist_validation(dataloader, current_iter, tb_logger, save_img)
        else:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def _initialize_best_metric_results(self, dataset_name):
        """Initialize the best metric results dict for recording the best metric value and iteration."""
        if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
            return
        elif not hasattr(self, 'best_metric_results'):
            self.best_metric_results = dict()

        # add a dataset record
        record = dict()
        for metric, content in self.opt['val']['metrics'].items():
            better = content.get('better', 'higher')
            init_val = float('-inf') if better == 'higher' else float('inf')
            record[metric] = dict(better=better, val=init_val, iter=-1)
        self.best_metric_results[dataset_name] = record

    def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
        if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
            if val >= self.best_metric_results[dataset_name][metric]['val']:
                self.best_metric_results[dataset_name][metric]['val'] = val
                self.best_metric_results[dataset_name][metric]['iter'] = current_iter
        else:
            if val <= self.best_metric_results[dataset_name][metric]['val']:
                self.best_metric_results[dataset_name][metric]['val'] = val
                self.best_metric_results[dataset_name][metric]['iter'] = current_iter

    def model_ema(self, decay=0.999):
        net_g = self.get_bare_model(self.net_g)

        net_g_params = dict(net_g.named_parameters())
        net_g_ema_params = dict(self.net_g_ema.named_parameters())

        for k in net_g_ema_params.keys():
            net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)

    def get_current_log(self):
        return self.log_dict

    def model_to_device(self, net):
        """Model to device. It also wraps models with DistributedDataParallel
        or DataParallel.

        Args:
            net (nn.Module)
        """
        net = net.to(self.device)
        if self.opt['dist']:
            find_unused_parameters = self.opt.get('find_unused_parameters', False)
            net = DistributedDataParallel(
                net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
        elif self.opt['num_gpu'] > 1:
            net = DataParallel(net)
        return net

    def get_optimizer(self, optim_type, params, lr, **kwargs):
        if optim_type == 'Adam':
            optimizer = torch.optim.Adam(params, lr, **kwargs)
        elif optim_type == 'AdamW':
            optimizer = torch.optim.AdamW(params, lr, **kwargs)
        elif optim_type == 'Adamax':
            optimizer = torch.optim.Adamax(params, lr, **kwargs)
        elif optim_type == 'SGD':
            optimizer = torch.optim.SGD(params, lr, **kwargs)
        elif optim_type == 'ASGD':
            optimizer = torch.optim.ASGD(params, lr, **kwargs)
        elif optim_type == 'RMSprop':
            optimizer = torch.optim.RMSprop(params, lr, **kwargs)
        elif optim_type == 'Rprop':
            optimizer = torch.optim.Rprop(params, lr, **kwargs)
        else:
            raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
        return optimizer

    def setup_schedulers(self):
        """Set up schedulers."""
        train_opt = self.opt['train']
        scheduler_type = train_opt['scheduler'].pop('type')
        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'CosineAnnealingRestartLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
        else:
            raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')

    def get_bare_model(self, net):
        """Get bare model, especially under wrapping with
        DistributedDataParallel or DataParallel.
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        return net

    @master_only
    def print_network(self, net):
        """Print the str and parameter number of a network.

        Args:
            net (nn.Module)
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
        else:
            net_cls_str = f'{net.__class__.__name__}'

        net = self.get_bare_model(net)
        net_str = str(net)
        net_params = sum(map(lambda x: x.numel(), net.parameters()))

        logger = get_root_logger()
        logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
        logger.info(net_str)

    def _set_lr(self, lr_groups_l):
        """Set learning rate for warm-up.

        Args:
            lr_groups_l (list): List for lr_groups, each for an optimizer.
        """
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        """Get the initial lr, which is set by the scheduler."""
        init_lr_groups_l = []
        for optimizer in self.optimizers:
            init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
        return init_lr_groups_l

    def update_learning_rate(self, current_iter, warmup_iter=-1):
        """Update learning rate.

        Args:
            current_iter (int): Current iteration.
            warmup_iter (int): Warm-up iter numbers. -1 for no warm-up.
                Default: -1.
        """
        if current_iter > 1:
            for scheduler in self.schedulers:
                scheduler.step()
        # set up warm-up learning rate
        if current_iter < warmup_iter:
            # get initial lr for each group
            init_lr_g_l = self._get_init_lr()
            # modify warming-up learning rates
            # currently only support linear warm-up
            warm_up_lr_l = []
            for init_lr_g in init_lr_g_l:
                warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
            # set learning rate
            self._set_lr(warm_up_lr_l)

    def get_current_learning_rate(self):
        return [param_group['lr'] for param_group in self.optimizers[0].param_groups]

    @master_only
    def save_network(self, net, net_label, current_iter, param_key='params'):
        """Save networks.

        Args:
            net (nn.Module | list[nn.Module]): Network(s) to be saved.
            net_label (str): Network label.
            current_iter (int): Current iter number.
            param_key (str | list[str]): The parameter key(s) to save network.
                Default: 'params'.
        """
        if current_iter == -1:
            current_iter = 'latest'
        save_filename = f'{net_label}_{current_iter}.pth'
        save_path = os.path.join(self.opt['path']['models'], save_filename)

        net = net if isinstance(net, list) else [net]
        param_key = param_key if isinstance(param_key, list) else [param_key]
        assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'

        save_dict = {}
        for net_, param_key_ in zip(net, param_key):
            net_ = self.get_bare_model(net_)
            state_dict = net_.state_dict()
            for key, param in state_dict.items():
                if key.startswith('module.'):  # remove unnecessary 'module.'
                    key = key[7:]
                state_dict[key] = param.cpu()
            save_dict[param_key_] = state_dict

        # avoid occasional writing errors
        retry = 3
        while retry > 0:
            try:
                torch.save(save_dict, save_path)
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
                time.sleep(1)
            else:
                break
            finally:
                retry -= 1
        if retry == 0:
            logger.warning(f'Still cannot save {save_path}. Just ignore it.')
            # raise IOError(f'Cannot save {save_path}.')

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print keys with different name or different size when loading models.

        1. Print keys with different names.
        2. If strict=False, print the same key but with different tensor size.
            It also ignores these keys with different sizes (not loaded).

        Args:
            crt_net (torch model): Current network.
            load_net (dict): Loaded network.
            strict (bool): Whether strictly loaded. Default: True.
        """
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        logger = get_root_logger()
        if crt_net_keys != load_net_keys:
            logger.warning('Current net - loaded net:')
            for v in sorted(list(crt_net_keys - load_net_keys)):
                logger.warning(f'{v}')
            logger.warning('Loaded net - current net:')
            for v in sorted(list(load_net_keys - crt_net_keys)):
                logger.warning(f'{v}')

        # check the size for the same keys
        if not strict:
            common_keys = crt_net_keys & load_net_keys
            for k in common_keys:
                if crt_net[k].size() != load_net[k].size():
                    logger.warning(f'Size different, ignore [{k}]: crt_net: '
                                   f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)

    def load_network(self, net, load_path, strict=True, param_key='params'):
        """Load network.

        Args:
            load_path (str): The path of networks to be loaded.
            net (nn.Module): Network.
            strict (bool): Whether strictly loaded.
            param_key (str): The parameter key of loaded network. If set to
                None, use the root 'path'.
                Default: 'params'.
        """
        logger = get_root_logger()
        net = self.get_bare_model(net)
        load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and 'params' in load_net:
                param_key = 'params'
                logger.info('Loading: params_ema does not exist, use params.')
            load_net = load_net[param_key]
        logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
        # remove unnecessary 'module.'
        for k, v in deepcopy(load_net).items():
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        self._print_different_keys_loading(net, load_net, strict)
        net.load_state_dict(load_net, strict=strict)

    @master_only
    def save_training_state(self, epoch, current_iter):
        """Save training states during training, which will be used for
        resuming.

        Args:
            epoch (int): Current epoch.
            current_iter (int): Current iteration.
        """
        if current_iter != -1:
            state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
            for o in self.optimizers:
                state['optimizers'].append(o.state_dict())
            for s in self.schedulers:
                state['schedulers'].append(s.state_dict())
            save_filename = f'{current_iter}.state'
            save_path = os.path.join(self.opt['path']['training_states'], save_filename)

            # avoid occasional writing errors
            retry = 3
            while retry > 0:
                try:
                    torch.save(state, save_path)
                except Exception as e:
                    logger = get_root_logger()
                    logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
                    time.sleep(1)
                else:
                    break
                finally:
                    retry -= 1
            if retry == 0:
                logger.warning(f'Still cannot save {save_path}. Just ignore it.')
                # raise IOError(f'Cannot save {save_path}.')

    def resume_training(self, resume_state):
        """Reload the optimizers and schedulers for resumed training.

        Args:
            resume_state (dict): Resume state.
        """
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)

    def reduce_loss_dict(self, loss_dict):
        """Reduce loss dict.

        In distributed training, it averages the losses among different GPUs.

        Args:
            loss_dict (OrderedDict): Loss dict.
        """
        with torch.no_grad():
            if self.opt['dist']:
                keys = []
                losses = []
                for name, value in loss_dict.items():
                    keys.append(name)
                    losses.append(value)
                losses = torch.stack(losses, 0)
                torch.distributed.reduce(losses, dst=0)
                if self.opt['rank'] == 0:
                    losses /= self.opt['world_size']
                loss_dict = {key: loss for key, loss in zip(keys, losses)}

            log_dict = OrderedDict()
            for name, value in loss_dict.items():
                log_dict[name] = value.mean().item()

            return log_dict
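The linear warm-up in `update_learning_rate` scales each group's initial lr by `current_iter / warmup_iter`. A minimal, self-contained sketch of that schedule, detached from BaseModel; the lr value, warm-up length, and the manual `initial_lr` bookkeeping are illustrative (schedulers normally record `initial_lr` themselves):

import torch

net = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(net.parameters(), lr=2e-4)
# No scheduler is attached in this sketch, so record 'initial_lr' by hand,
# mirroring what _get_init_lr reads.
for group in optimizer.param_groups:
    group['initial_lr'] = group['lr']

warmup_iter = 4
for current_iter in range(1, 7):
    if current_iter < warmup_iter:  # same condition as update_learning_rate
        for group in optimizer.param_groups:
            group['lr'] = group['initial_lr'] / warmup_iter * current_iter
    else:  # past warm-up the scheduler would own the lr; use the base value here
        for group in optimizer.param_groups:
            group['lr'] = group['initial_lr']
    print(current_iter, optimizer.param_groups[0]['lr'])
# iters 1..3 ramp 5e-05 -> 1e-04 -> 1.5e-04, then the full 2e-04 applies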
BasicSR/basicsr/models/edvr_model.py (new file, mode 100644)
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
from .video_base_model import VideoBaseModel


@MODEL_REGISTRY.register()
class EDVRModel(VideoBaseModel):
    """EDVR Model.

    Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.  # noqa: E501
    """

    def __init__(self, opt):
        super(EDVRModel, self).__init__(opt)
        if self.is_train:
            self.train_tsa_iter = opt['train'].get('tsa_iter')

    def setup_optimizers(self):
        train_opt = self.opt['train']
        dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
        logger = get_root_logger()
        logger.info(f'Multiply the learning rate for dcn by {dcn_lr_mul}.')
        if dcn_lr_mul == 1:
            optim_params = self.net_g.parameters()
        else:  # separate dcn params and normal params for different lr
            normal_params = []
            dcn_params = []
            for name, param in self.net_g.named_parameters():
                if 'dcn' in name:
                    dcn_params.append(param)
                else:
                    normal_params.append(param)
            optim_params = [
                {  # add normal params first
                    'params': normal_params,
                    'lr': train_opt['optim_g']['lr']
                },
                {
                    'params': dcn_params,
                    'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
                },
            ]

        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)

    def optimize_parameters(self, current_iter):
        if self.train_tsa_iter:
            if current_iter == 1:
                logger = get_root_logger()
                logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
                for name, param in self.net_g.named_parameters():
                    if 'fusion' not in name:
                        param.requires_grad = False
            elif current_iter == self.train_tsa_iter:
                logger = get_root_logger()
                logger.warning('Train all the parameters.')
                for param in self.net_g.parameters():
                    param.requires_grad = True

        super(EDVRModel, self).optimize_parameters(current_iter)
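`setup_optimizers` leans on PyTorch per-parameter-group learning rates to give DCN weights a different lr. A minimal sketch of that mechanism; the toy module and multiplier value are made up, only the name-based split mirrors the code above:

import torch
import torch.nn as nn

# Toy network standing in for net_g; 'dcn' in the attribute name triggers
# the same name-based split as EDVRModel.setup_optimizers.
net = nn.Module()
net.body = nn.Linear(8, 8)
net.dcn_offset = nn.Linear(8, 8)

normal_params, dcn_params = [], []
for name, param in net.named_parameters():
    (dcn_params if 'dcn' in name else normal_params).append(param)

base_lr, dcn_lr_mul = 4e-4, 0.1  # illustrative values
optimizer = torch.optim.Adam([
    {'params': normal_params, 'lr': base_lr},
    {'params': dcn_params, 'lr': base_lr * dcn_lr_mul},
])
print([g['lr'] for g in optimizer.param_groups])  # [0.0004, 4e-05]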
BasicSR/basicsr/models/esrgan_model.py (new file, mode 100644)
import torch
from collections import OrderedDict

from basicsr.utils.registry import MODEL_REGISTRY
from .srgan_model import SRGANModel


@MODEL_REGISTRY.register()
class ESRGANModel(SRGANModel):
    """ESRGAN model for single image super-resolution."""

    def optimize_parameters(self, current_iter):
        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss (relativistic gan)
            real_d_pred = self.net_d(self.gt).detach()
            fake_g_pred = self.net_d(self.output)
            l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
            l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
            l_g_gan = (l_g_real + l_g_fake) / 2

            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # gan loss (relativistic gan)

        # In order to avoid the error in distributed training:
        # "Error detected in CudnnBatchNormBackward: RuntimeError: one of
        # the variables needed for gradient computation has been modified by
        # an inplace operation",
        # we separate the backwards for real and fake, and also detach the
        # tensor for calculating mean.

        # real
        fake_d_pred = self.net_d(self.output).detach()
        real_d_pred = self.net_d(self.gt)
        l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach())
        l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5
        l_d_fake.backward()
        self.optimizer_d.step()

        loss_dict['l_d_real'] = l_d_real
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)
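The relativistic generator loss above asks whether a prediction scores higher or lower than the batch-mean prediction on the other class. A standalone sketch of those two terms using a plain BCE-with-logits criterion; BasicSR's `cri_gan` wraps configurable loss types, so this only shows the arithmetic, and the logits are made-up toy values:

import torch
import torch.nn.functional as F

real_d_pred = torch.tensor([1.2, 0.8, 1.0])    # toy discriminator logits on real images
fake_g_pred = torch.tensor([-0.5, 0.1, -0.2])  # toy logits on generated images

def bce(logits, target_is_real):
    target = torch.full_like(logits, float(target_is_real))
    return F.binary_cross_entropy_with_logits(logits, target)

# generator side: push fake scores above the average real score, and
# push real scores below the average fake score
l_g_real = bce(real_d_pred - fake_g_pred.mean(), False)
l_g_fake = bce(fake_g_pred - real_d_pred.mean(), True)
l_g_gan = (l_g_real + l_g_fake) / 2
print(f'{l_g_gan.item():.4f}')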
BasicSR/basicsr/models/hifacegan_model.py (new file, mode 100644)
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel


@MODEL_REGISTRY.register()
class HiFaceGANModel(SRModel):
    """HiFaceGAN model for generic-purpose face restoration.

    No prior modeling required; works for any degradations.
    Currently doesn't support EMA for inference.
    """

    def init_training_settings(self):
        train_opt = self.opt['train']
        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            raise NotImplementedError('HiFaceGAN does not support EMA now. Pass')

        self.net_g.train()

        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)

        # define losses
        # HiFaceGAN does not use pixel loss by default
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        if train_opt.get('feature_matching_opt'):
            self.cri_feat = build_loss(train_opt['feature_matching_opt']).to(self.device)
        else:
            self.cri_feat = None

        if self.cri_pix is None and self.cri_perceptual is None:
            raise ValueError('Both pixel and perceptual losses are None.')

        if train_opt.get('gan_opt'):
            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        train_opt = self.opt['train']
        # optimizer g
        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)
        # optimizer d
        optim_type = train_opt['optim_d'].pop('type')
        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
        self.optimizers.append(self.optimizer_d)

    def discriminate(self, input_lq, output, ground_truth):
        """
        This is a conditional (on the input) discriminator.
        In Batch Normalization, the fake and real images are
        recommended to be in the same batch to avoid disparate
        statistics between fake and real images.
        So both fake and real images are fed to D all at once.
        """
        h, w = output.shape[-2:]
        if output.shape[-2:] != input_lq.shape[-2:]:
            lq = torch.nn.functional.interpolate(input_lq, (h, w))
            real = torch.nn.functional.interpolate(ground_truth, (h, w))
            fake_concat = torch.cat([lq, output], dim=1)
            real_concat = torch.cat([lq, real], dim=1)
        else:
            fake_concat = torch.cat([input_lq, output], dim=1)
            real_concat = torch.cat([input_lq, ground_truth], dim=1)

        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
        discriminator_out = self.net_d(fake_and_real)
        pred_fake, pred_real = self._divide_pred(discriminator_out)
        return pred_fake, pred_real

    @staticmethod
    def _divide_pred(pred):
        """
        Take the predictions for fake and real images from the combined batch.
        The prediction contains the intermediate outputs of a multiscale GAN,
        so it is usually a list.
        """
        if isinstance(pred, list):
            fake = []
            real = []
            for p in pred:
                fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
                real.append([tensor[tensor.size(0) // 2:] for tensor in p])
        else:
            fake = pred[:pred.size(0) // 2]
            real = pred[pred.size(0) // 2:]

        return fake, real

    def optimize_parameters(self, current_iter):
        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()

        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix

            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style

            # Requires real prediction for feature matching loss
            pred_fake, pred_real = self.discriminate(self.lq, self.output, self.gt)
            l_g_gan = self.cri_gan(pred_fake, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            # feature matching loss
            if self.cri_feat:
                l_g_feat = self.cri_feat(pred_fake, pred_real)
                l_g_total += l_g_feat
                loss_dict['l_g_feat'] = l_g_feat

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()

        # TODO: Benchmark test between HiFaceGAN and SRGAN implementation:
        # SRGAN uses the same fake output for the discriminator update,
        # while HiFaceGAN regenerates a new output using the updated net_g.
        # This should not make too much difference though. Stick to SRGAN now.
        # -------------------------------------------------------------------
        # ---------- Below is the original HiFaceGAN code snippet -----------
        # -------------------------------------------------------------------
        # with torch.no_grad():
        #     fake_image = self.net_g(self.lq)
        #     fake_image = fake_image.detach()
        #     fake_image.requires_grad_()
        #     pred_fake, pred_real = self.discriminate(self.lq, fake_image, self.gt)

        # real
        pred_fake, pred_real = self.discriminate(self.lq, self.output.detach(), self.gt)
        l_d_real = self.cri_gan(pred_real, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        # fake
        l_d_fake = self.cri_gan(pred_fake, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake

        l_d_total = (l_d_real + l_d_fake) / 2
        l_d_total.backward()
        self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            print('HiFaceGAN does not support EMA now. pass')

    def validation(self, dataloader, current_iter, tb_logger, save_img=False):
        """
        Warning: HiFaceGAN requires train() mode even for validation.
        For more info, see https://github.com/Lotayou/Face-Renovation/issues/31

        Args:
            dataloader (torch.utils.data.DataLoader): Validation dataloader.
            current_iter (int): Current iteration.
            tb_logger (tensorboard logger): Tensorboard logger.
            save_img (bool): Whether to save images. Default: False.
        """
        if self.opt['network_g']['type'] in ('HiFaceGAN', 'SPADEGenerator'):
            self.net_g.train()

        if self.opt['dist']:
            self.dist_validation(dataloader, current_iter, tb_logger, save_img)
        else:
            print('In HiFaceGANModel: The new metrics package is under development. '
                  'Using super method now (Only PSNR & SSIM are supported)')
            super().nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """
        TODO: Validation using the updated metric system.
        The metrics are now evaluated after all images have been tested.
        This allows batch processing, and also allows evaluation of
        distributional metrics, such as:

        @ Frechet Inception Distance: FID
        @ Maximum Mean Discrepancy: MMD

        Warning:
            Needs careful batch management for different inference settings.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = dict()  # {metric: 0 for metric in self.opt['val']['metrics'].keys()}
        sr_tensors = []
        gt_tensors = []

        pbar = tqdm(total=len(dataloader), unit='image')
        for val_data in dataloader:
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()  # detached cpu tensor, non-squeeze
            sr_tensors.append(visuals['result'])
            if 'gt' in visuals:
                gt_tensors.append(visuals['gt'])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(tensor2img(visuals['result']), save_img_path)

            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
        pbar.close()

        if with_metrics:
            sr_pack = torch.cat(sr_tensors, dim=0)
            gt_pack = torch.cat(gt_tensors, dim=0)
            # calculate metrics
            for name, opt_ in self.opt['val']['metrics'].items():
                # The new metric caller automatically returns the mean value
                # FIXME: ERROR: calculate_metric only supports two arguments. Now the codes cannot be successfully run
                self.metric_results[name] = calculate_metric(dict(sr_pack=sr_pack, gt_pack=gt_pack), opt_)
            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def save(self, epoch, current_iter):
        if hasattr(self, 'net_g_ema'):
            print('HiFaceGAN does not support EMA now. Fallback to normal mode.')

        self.save_network(self.net_g, 'net_g', current_iter)
        self.save_network(self.net_d, 'net_d', current_iter)
        self.save_training_state(epoch, current_iter)
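Since `discriminate` pushes fakes and reals through net_d as one combined batch, `_divide_pred` only needs to split every prediction tensor in half along the batch dimension. A toy check of that split; the shapes and the simulated multiscale output are illustrative:

import torch

# Simulated multiscale discriminator output: a list of per-scale predictions,
# each holding intermediate feature tensors for the combined fake+real batch of 6.
pred = [[torch.randn(6, 1, 16, 16), torch.randn(6, 1, 8, 8)]]

fake = [[t[:t.size(0) // 2] for t in p] for p in pred]  # first half: fakes
real = [[t[t.size(0) // 2:] for t in p] for p in pred]  # second half: reals

# Both halves keep the spatial shape and get batch size 3.
print(fake[0][0].shape, real[0][0].shape)  # torch.Size([3, 1, 16, 16]) twice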
BasicSR/basicsr/models/lr_scheduler.py (new file, mode 100644)
import math
from collections import Counter
from torch.optim.lr_scheduler import _LRScheduler


class MultiStepRestartLR(_LRScheduler):
    """MultiStep with restarts learning rate scheme.

    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        milestones (list): Iterations that will decrease learning rate.
        gamma (float): Decrease ratio. Default: 0.1.
        restarts (list): Restart iterations. Default: [0].
        restart_weights (list): Restart weights at each restart iteration.
            Default: [1].
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.restarts = restarts
        self.restart_weights = restart_weights
        assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch in self.restarts:
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]


def get_position_from_periods(iteration, cumulative_period):
    """Get the position from a period list.

    It will return the index of the right-closest number in the period list.
    For example, if cumulative_period = [100, 200, 300, 400]:
        if iteration == 50, return 0;
        if iteration == 210, return 2;
        if iteration == 300, return 2.

    Args:
        iteration (int): Current iteration.
        cumulative_period (list[int]): Cumulative period list.

    Returns:
        int: The position of the right-closest number in the period list.
    """
    for i, period in enumerate(cumulative_period):
        if iteration <= period:
            return i


class CosineAnnealingRestartLR(_LRScheduler):
    """Cosine annealing with restarts learning rate scheme.

    An example config:
        periods = [10, 10, 10, 10]
        restart_weights = [1, 0.5, 0.5, 0.5]
        eta_min = 1e-7

    It has four cycles, each with 10 iterations. At the 10th, 20th and 30th
    iterations, the scheduler will restart with the weights in restart_weights.

    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        periods (list): Period for each cosine annealing cycle.
        restart_weights (list): Restart weights at each restart iteration.
            Default: [1].
        eta_min (float): The minimum lr. Default: 0.
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
        self.periods = periods
        self.restart_weights = restart_weights
        self.eta_min = eta_min
        assert (len(self.periods) == len(
            self.restart_weights)), 'periods and restart_weights should have the same length.'
        self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
        super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
        current_weight = self.restart_weights[idx]
        nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
        current_period = self.periods[idx]

        return [
            self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
            (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
            for base_lr in self.base_lrs
        ]
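The docstring's example config can be run directly to watch the restarts happen. A minimal sketch; the network, optimizer, and base lr are placeholders, and it assumes the basicsr package from this commit is importable:

import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

net = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
scheduler = CosineAnnealingRestartLR(
    optimizer, periods=[10, 10, 10, 10], restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)

for it in range(1, 22):
    optimizer.step()
    scheduler.step()
    print(it, f"{optimizer.param_groups[0]['lr']:.2e}")
# The lr anneals to eta_min by iteration 10, then restarts near half the
# base lr (restart weight 0.5) and anneals again over the next period.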
BasicSR/basicsr/models/realesrgan_model.py (new file, mode 100644)
import
numpy
as
np
import
random
import
torch
from
collections
import
OrderedDict
from
torch.nn
import
functional
as
F
from
basicsr.data.degradations
import
random_add_gaussian_noise_pt
,
random_add_poisson_noise_pt
from
basicsr.data.transforms
import
paired_random_crop
from
basicsr.losses.loss_util
import
get_refined_artifact_map
from
basicsr.models.srgan_model
import
SRGANModel
from
basicsr.utils
import
DiffJPEG
,
USMSharp
from
basicsr.utils.img_process_util
import
filter2D
from
basicsr.utils.registry
import
MODEL_REGISTRY
@
MODEL_REGISTRY
.
register
(
suffix
=
'basicsr'
)
class
RealESRGANModel
(
SRGANModel
):
"""RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
It mainly performs:
1. randomly synthesize LQ images in GPU tensors
2. optimize the networks with GAN training.
"""
def
__init__
(
self
,
opt
):
super
(
RealESRGANModel
,
self
).
__init__
(
opt
)
self
.
jpeger
=
DiffJPEG
(
differentiable
=
False
).
cuda
()
# simulate JPEG compression artifacts
self
.
usm_sharpener
=
USMSharp
().
cuda
()
# do usm sharpening
self
.
queue_size
=
opt
.
get
(
'queue_size'
,
180
)
@
torch
.
no_grad
()
def
_dequeue_and_enqueue
(
self
):
"""It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize
b
,
c
,
h
,
w
=
self
.
lq
.
size
()
if
not
hasattr
(
self
,
'queue_lr'
):
assert
self
.
queue_size
%
b
==
0
,
f
'queue size
{
self
.
queue_size
}
should be divisible by batch size
{
b
}
'
self
.
queue_lr
=
torch
.
zeros
(
self
.
queue_size
,
c
,
h
,
w
).
cuda
()
_
,
c
,
h
,
w
=
self
.
gt
.
size
()
self
.
queue_gt
=
torch
.
zeros
(
self
.
queue_size
,
c
,
h
,
w
).
cuda
()
self
.
queue_ptr
=
0
if
self
.
queue_ptr
==
self
.
queue_size
:
# the pool is full
# do dequeue and enqueue
# shuffle
idx
=
torch
.
randperm
(
self
.
queue_size
)
self
.
queue_lr
=
self
.
queue_lr
[
idx
]
self
.
queue_gt
=
self
.
queue_gt
[
idx
]
# get first b samples
lq_dequeue
=
self
.
queue_lr
[
0
:
b
,
:,
:,
:].
clone
()
gt_dequeue
=
self
.
queue_gt
[
0
:
b
,
:,
:,
:].
clone
()
# update the queue
self
.
queue_lr
[
0
:
b
,
:,
:,
:]
=
self
.
lq
.
clone
()
self
.
queue_gt
[
0
:
b
,
:,
:,
:]
=
self
.
gt
.
clone
()
self
.
lq
=
lq_dequeue
self
.
gt
=
gt_dequeue
else
:
# only do enqueue
self
.
queue_lr
[
self
.
queue_ptr
:
self
.
queue_ptr
+
b
,
:,
:,
:]
=
self
.
lq
.
clone
()
self
.
queue_gt
[
self
.
queue_ptr
:
self
.
queue_ptr
+
b
,
:,
:,
:]
=
self
.
gt
.
clone
()
self
.
queue_ptr
=
self
.
queue_ptr
+
b
@
torch
.
no_grad
()
def
feed_data
(
self
,
data
):
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
"""
if
self
.
is_train
and
self
.
opt
.
get
(
'high_order_degradation'
,
True
):
# training data synthesis
self
.
gt
=
data
[
'gt'
].
to
(
self
.
device
)
self
.
gt_usm
=
self
.
usm_sharpener
(
self
.
gt
)
self
.
kernel1
=
data
[
'kernel1'
].
to
(
self
.
device
)
self
.
kernel2
=
data
[
'kernel2'
].
to
(
self
.
device
)
self
.
sinc_kernel
=
data
[
'sinc_kernel'
].
to
(
self
.
device
)
ori_h
,
ori_w
=
self
.
gt
.
size
()[
2
:
4
]
# ----------------------- The first degradation process ----------------------- #
# blur
out
=
filter2D
(
self
.
gt_usm
,
self
.
kernel1
)
# random resize
updown_type
=
random
.
choices
([
'up'
,
'down'
,
'keep'
],
self
.
opt
[
'resize_prob'
])[
0
]
if
updown_type
==
'up'
:
scale
=
np
.
random
.
uniform
(
1
,
self
.
opt
[
'resize_range'
][
1
])
elif
updown_type
==
'down'
:
scale
=
np
.
random
.
uniform
(
self
.
opt
[
'resize_range'
][
0
],
1
)
else
:
scale
=
1
mode
=
random
.
choice
([
'area'
,
'bilinear'
,
'bicubic'
])
out
=
F
.
interpolate
(
out
,
scale_factor
=
scale
,
mode
=
mode
)
# add noise
gray_noise_prob
=
self
.
opt
[
'gray_noise_prob'
]
if
np
.
random
.
uniform
()
<
self
.
opt
[
'gaussian_noise_prob'
]:
out
=
random_add_gaussian_noise_pt
(
out
,
sigma_range
=
self
.
opt
[
'noise_range'
],
clip
=
True
,
rounds
=
False
,
gray_prob
=
gray_noise_prob
)
else
:
out
=
random_add_poisson_noise_pt
(
out
,
scale_range
=
self
.
opt
[
'poisson_scale_range'
],
gray_prob
=
gray_noise_prob
,
clip
=
True
,
rounds
=
False
)
# JPEG compression
jpeg_p
=
out
.
new_zeros
(
out
.
size
(
0
)).
uniform_
(
*
self
.
opt
[
'jpeg_range'
])
out
=
torch
.
clamp
(
out
,
0
,
1
)
# clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out
=
self
.
jpeger
(
out
,
quality
=
jpeg_p
)
# ----------------------- The second degradation process ----------------------- #
# blur
if
np
.
random
.
uniform
()
<
self
.
opt
[
'second_blur_prob'
]:
out
=
filter2D
(
out
,
self
.
kernel2
)
# random resize
updown_type
=
random
.
choices
([
'up'
,
'down'
,
'keep'
],
self
.
opt
[
'resize_prob2'
])[
0
]
if
updown_type
==
'up'
:
scale
=
np
.
random
.
uniform
(
1
,
self
.
opt
[
'resize_range2'
][
1
])
elif
updown_type
==
'down'
:
scale
=
np
.
random
.
uniform
(
self
.
opt
[
'resize_range2'
][
0
],
1
)
else
:
scale
=
1
mode
=
random
.
choice
([
'area'
,
'bilinear'
,
'bicubic'
])
out
=
F
.
interpolate
(
out
,
size
=
(
int
(
ori_h
/
self
.
opt
[
'scale'
]
*
scale
),
int
(
ori_w
/
self
.
opt
[
'scale'
]
*
scale
)),
mode
=
mode
)
# add noise
gray_noise_prob
=
self
.
opt
[
'gray_noise_prob2'
]
if
np
.
random
.
uniform
()
<
self
.
opt
[
'gaussian_noise_prob2'
]:
out
=
random_add_gaussian_noise_pt
(
out
,
sigma_range
=
self
.
opt
[
'noise_range2'
],
clip
=
True
,
rounds
=
False
,
gray_prob
=
gray_noise_prob
)
else
:
out
=
random_add_poisson_noise_pt
(
out
,
scale_range
=
self
.
opt
[
'poisson_scale_range2'
],
gray_prob
=
gray_noise_prob
,
clip
=
True
,
rounds
=
False
)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
            # 2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
                                                                 self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
            self.gt_usm = self.usm_sharpener(self.gt)
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True

    def optimize_parameters(self, current_iter):
        # usm sharpening
        l1_gt = self.gt_usm
        percep_gt = self.gt_usm
        gan_gt = self.gt_usm
        if self.opt['l1_gt_usm'] is False:
            l1_gt = self.gt
        if self.opt['percep_gt_usm'] is False:
            percep_gt = self.gt
        if self.opt['gan_gt_usm'] is False:
            gan_gt = self.gt

        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)
        if self.cri_ldl:
            self.output_ema = self.net_g_ema(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, l1_gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            if self.cri_ldl:
                pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7)
                l_g_ldl = self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt))
                l_g_total += l_g_ldl
                loss_dict['l_g_ldl'] = l_g_ldl
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(gan_gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach().clone())  # clone for pt1.9
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        self.log_dict = self.reduce_loss_dict(loss_dict)
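The `paired_random_crop` call above must keep the GT and LQ crops aligned under the scale factor. As a rough illustration of the invariant it relies on (a simplified sketch with an illustrative name; the real `basicsr.data.transforms.paired_random_crop` also handles image lists and other input types):

import torch

def paired_random_crop_sketch(gt, lq, gt_patch_size, scale):
    """Minimal sketch: crop aligned patches from GT and LQ tensors (b, c, h, w).

    The LQ window is sampled first; the GT window is the same window scaled
    by `scale`, so (gt_top, gt_left) = (lq_top, lq_left) * scale.
    """
    lq_patch_size = gt_patch_size // scale
    _, _, h_lq, w_lq = lq.size()
    top = torch.randint(0, h_lq - lq_patch_size + 1, (1,)).item()
    left = torch.randint(0, w_lq - lq_patch_size + 1, (1,)).item()
    lq_crop = lq[:, :, top:top + lq_patch_size, left:left + lq_patch_size]
    gt_crop = gt[:, :, top * scale:top * scale + gt_patch_size, left * scale:left * scale + gt_patch_size]
    return gt_crop, lq_crop

gt = torch.rand(1, 3, 256, 256)
lq = torch.rand(1, 3, 64, 64)  # 4x downscaled counterpart
gt_crop, lq_crop = paired_random_crop_sketch(gt, lq, gt_patch_size=128, scale=4)
assert gt_crop.shape[-1] == 4 * lq_crop.shape[-1]  # crops stay aligned under the scale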
BasicSR/basicsr/models/realesrnet_model.py
0 → 100644
View file @ e2696ece

import numpy as np
import random
import torch
from torch.nn import functional as F

from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.sr_model import SRModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register(suffix='basicsr')
class RealESRNetModel(SRModel):
    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It is trained without GAN losses.
    It mainly performs:
    1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with pixel-based losses (no GAN training)
    """

    def __init__(self, opt):
        super(RealESRNetModel, self).__init__(opt)
        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
        self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.
        """
        # initialize
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_lr'):
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
            self.queue_ptr = 0
        if self.queue_ptr == self.queue_size:  # the pool is full
            # do dequeue and enqueue
            # shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            # USM sharpen the GT images
            if self.opt['gt_usm'] is True:
                self.gt = self.usm_sharpener(self.gt)

            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True
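The `_dequeue_and_enqueue` pool above trades the current batch against a shuffled reservoir so that degradations synthesized in different iterations get mixed within a batch. A stripped-down, CPU-only sketch of the same pointer/shuffle logic (class and method names here are illustrative, not part of BasicSR's API):

import torch

class PairPool:
    """Toy version of the training pair pool: enqueue until full, then
    shuffle the pool and swap the incoming batch with its first b samples."""

    def __init__(self, queue_size):
        self.queue_size = queue_size
        self.queue = None
        self.ptr = 0

    def exchange(self, batch):
        b = batch.size(0)
        if self.queue is None:
            assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
            self.queue = torch.zeros(self.queue_size, *batch.shape[1:])
        if self.ptr == self.queue_size:  # pool is full: shuffle, then swap
            idx = torch.randperm(self.queue_size)
            self.queue = self.queue[idx]
            out = self.queue[0:b].clone()
            self.queue[0:b] = batch.clone()
            return out
        # pool not yet full: only enqueue, pass the batch through unchanged
        self.queue[self.ptr:self.ptr + b] = batch.clone()
        self.ptr += b
        return batch

pool = PairPool(queue_size=8)
for step in range(6):
    batch = torch.full((4, 3, 2, 2), float(step))
    batch = pool.exchange(batch)  # after the pool fills, returns a shuffled mix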
BasicSR/basicsr/models/sr_model.py
0 → 100644
View file @ e2696ece

import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel


@MODEL_REGISTRY.register()
class SRModel(BaseModel):
    """Base SR model for single image super-resolution."""

    def __init__(self, opt):
        super(SRModel, self).__init__(opt)

        # define network
        self.net_g = build_network(opt['network_g'])
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_g', 'params')
            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

        if self.is_train:
            self.init_training_settings()

    def init_training_settings(self):
        self.net_g.train()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define losses
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        if self.cri_pix is None and self.cri_perceptual is None:
            raise ValueError('Both pixel and perceptual losses are None.')

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        train_opt = self.opt['train']
        optim_params = []
        for k, v in self.net_g.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger = get_root_logger()
                logger.warning(f'Params {k} will not be optimized.')

        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)

    def feed_data(self, data):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

    def optimize_parameters(self, current_iter):
        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_total = 0
        loss_dict = OrderedDict()
        # pixel loss
        if self.cri_pix:
            l_pix = self.cri_pix(self.output, self.gt)
            l_total += l_pix
            loss_dict['l_pix'] = l_pix
        # perceptual loss
        if self.cri_perceptual:
            l_percep, l_style = self.cri_perceptual(self.output, self.gt)
            if l_percep is not None:
                l_total += l_percep
                loss_dict['l_percep'] = l_percep
            if l_style is not None:
                l_total += l_style
                loss_dict['l_style'] = l_style

        l_total.backward()
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def test(self):
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                self.output = self.net_g_ema(self.lq)
        else:
            self.net_g.eval()
            with torch.no_grad():
                self.output = self.net_g(self.lq)
            self.net_g.train()

    def test_selfensemble(self):
        # TODO: to be tested
        # 8 augmentations
        # modified from https://github.com/thstkdgus35/EDSR-PyTorch

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        # prepare augmented data
        lq_list = [self.lq]
        for tf in 'v', 'h', 't':
            lq_list.extend([_transform(t, tf) for t in lq_list])

        # inference
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                out_list = [self.net_g_ema(aug) for aug in lq_list]
        else:
            self.net_g.eval()
            with torch.no_grad():
                out_list = [self.net_g(aug) for aug in lq_list]
            self.net_g.train()

        # merge results
        for i in range(len(out_list)):
            if i > 3:
                out_list[i] = _transform(out_list[i], 't')
            if i % 4 > 1:
                out_list[i] = _transform(out_list[i], 'h')
            if (i % 4) % 2 == 1:
                out_list[i] = _transform(out_list[i], 'v')
        output = torch.cat(out_list, dim=0)

        self.output = output.mean(dim=0, keepdim=True)

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        use_pbar = self.opt['val'].get('pbar', False)

        if with_metrics:
            if not hasattr(self, 'metric_results'):  # only execute in the first run
                self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
            # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
            self._initialize_best_metric_results(dataset_name)
        # zero self.metric_results
        if with_metrics:
            self.metric_results = {metric: 0 for metric in self.metric_results}

        metric_data = dict()
        if use_pbar:
            pbar = tqdm(total=len(dataloader), unit='image')

        for idx, val_data in enumerate(dataloader):
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            metric_data['img'] = sr_img
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                metric_data['img2'] = gt_img
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt['val']['metrics'].items():
                    self.metric_results[name] += calculate_metric(metric_data, opt_)
            if use_pbar:
                pbar.update(1)
                pbar.set_description(f'Test {img_name}')
        if use_pbar:
            pbar.close()

        if with_metrics:
            for metric in self.metric_results.keys():
                self.metric_results[metric] /= (idx + 1)
                # update the best metric result
                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}'
            if hasattr(self, 'best_metric_results'):
                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
            log_str += '\n'

        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['lq'] = self.lq.detach().cpu()
        out_dict['result'] = self.output.detach().cpu()
        if hasattr(self, 'gt'):
            out_dict['gt'] = self.gt.detach().cpu()
        return out_dict

    def save(self, epoch, current_iter):
        if hasattr(self, 'net_g_ema'):
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_training_state(epoch, current_iter)
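The merge loop in `test_selfensemble` works because each of the 'v', 'h', 't' transforms is its own inverse, and the index of an element in `lq_list` encodes which transforms were applied when the list was doubled. A quick round-trip check of that property (a sketch using plain tensor ops rather than the `_transform` helper above):

import torch

def tf(x, op):
    # 'v': flip the last dim, 'h': flip dim -2, 't': transpose the spatial dims
    if op == 'v':
        return x.flip(-1)
    if op == 'h':
        return x.flip(-2)
    return x.transpose(-1, -2)  # 't'

x = torch.arange(12.).reshape(1, 1, 3, 4)
for op in ('v', 'h', 't'):
    assert torch.equal(tf(tf(x, op), op), x)  # each transform is self-inverse

# Index i in out_list encodes its augmentation as bits of (t, h, v) in the
# doubling order v -> h -> t used above, so undoing in the reverse order
# 't' (i > 3), then 'h' (i % 4 > 1), then 'v' (i % 2 == 1) restores alignment.
aug = tf(tf(x, 'v'), 'h')           # i == 3: 'v' applied first, then 'h'
restored = tf(tf(aug, 'h'), 'v')    # undo in reverse order
assert torch.equal(restored, x)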
BasicSR/basicsr/models/srgan_model.py
0 → 100644
View file @ e2696ece

import torch
from collections import OrderedDict

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel


@MODEL_REGISTRY.register()
class SRGANModel(SRModel):
    """SRGAN model for single image super-resolution."""

    def init_training_settings(self):
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define network net_d
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_d', 'params')
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)

        self.net_g.train()
        self.net_d.train()

        # define losses
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('ldl_opt'):
            self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device)
        else:
            self.cri_ldl = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        if train_opt.get('gan_opt'):
            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        train_opt = self.opt['train']
        # optimizer g
        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)
        # optimizer d
        optim_type = train_opt['optim_d'].pop('type')
        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
        self.optimizers.append(self.optimizer_d)

    def optimize_parameters(self, current_iter):
        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(self.gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach())
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def save(self, epoch, current_iter):
        if hasattr(self, 'net_g_ema'):
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_network(self.net_d, 'net_d', current_iter)
        self.save_training_state(epoch, current_iter)
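In `optimize_parameters` above, the generator only receives a loss when `current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters`, while the discriminator is updated on every call. A tiny check of which iterations update the generator under sample settings (the values here are illustrative, not defaults from any config):

net_d_iters, net_d_init_iters = 2, 3

g_updates = [i for i in range(1, 11) if i % net_d_iters == 0 and i > net_d_init_iters]
print(g_updates)  # [4, 6, 8, 10]: G is skipped early on, then updated every 2nd iteration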
BasicSR/basicsr/models/stylegan2_model.py
0 → 100644
View file @ e2696ece

import cv2
import math
import numpy as np
import random
import torch
from collections import OrderedDict
from os import path as osp

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.gan_loss import g_path_regularize, r1_penalty
from basicsr.utils import imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel


@MODEL_REGISTRY.register()
class StyleGAN2Model(BaseModel):
    """StyleGAN2 model."""

    def __init__(self, opt):
        super(StyleGAN2Model, self).__init__(opt)

        # define network net_g
        self.net_g = build_network(opt['network_g'])
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_g', 'params')
            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

        # latent dimension: self.num_style_feat
        self.num_style_feat = opt['network_g']['num_style_feat']
        num_val_samples = self.opt['val'].get('num_val_samples', 16)
        self.fixed_sample = torch.randn(num_val_samples, self.num_style_feat, device=self.device)

        if self.is_train:
            self.init_training_settings()

    def init_training_settings(self):
        train_opt = self.opt['train']

        # define network net_d
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_d', 'params')
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)

        # define network net_g with Exponential Moving Average (EMA)
        # net_g_ema only used for testing on one GPU and saving, do not need to
        # wrap with DistributedDataParallel
        self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
        else:
            self.model_ema(0)  # copy net_g weight

        self.net_g.train()
        self.net_d.train()
        self.net_g_ema.eval()

        # define losses
        # gan loss (wgan)
        self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
        # regularization weights
        self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
        self.path_reg_weight = train_opt['path_reg_weight']  # for generator

        self.net_g_reg_every = train_opt['net_g_reg_every']
        self.net_d_reg_every = train_opt['net_d_reg_every']
        self.mixing_prob = train_opt['mixing_prob']

        self.mean_path_length = 0

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        train_opt = self.opt['train']
        # optimizer g
        net_g_reg_ratio = self.net_g_reg_every / (self.net_g_reg_every + 1)
        if self.opt['network_g']['type'] == 'StyleGAN2GeneratorC':
            normal_params = []
            style_mlp_params = []
            modulation_conv_params = []
            for name, param in self.net_g.named_parameters():
                if 'modulation' in name:
                    normal_params.append(param)
                elif 'style_mlp' in name:
                    style_mlp_params.append(param)
                elif 'modulated_conv' in name:
                    modulation_conv_params.append(param)
                else:
                    normal_params.append(param)
            optim_params_g = [
                {  # add normal params first
                    'params': normal_params,
                    'lr': train_opt['optim_g']['lr']
                },
                {
                    'params': style_mlp_params,
                    'lr': train_opt['optim_g']['lr'] * 0.01
                },
                {
                    'params': modulation_conv_params,
                    'lr': train_opt['optim_g']['lr'] / 3
                }
            ]
        else:
            normal_params = []
            for name, param in self.net_g.named_parameters():
                normal_params.append(param)
            optim_params_g = [{  # add normal params first
                'params': normal_params,
                'lr': train_opt['optim_g']['lr']
            }]

        optim_type = train_opt['optim_g'].pop('type')
        lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
        betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
        self.optimizers.append(self.optimizer_g)

        # optimizer d
        net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
        if self.opt['network_d']['type'] == 'StyleGAN2DiscriminatorC':
            normal_params = []
            linear_params = []
            for name, param in self.net_d.named_parameters():
                if 'final_linear' in name:
                    linear_params.append(param)
                else:
                    normal_params.append(param)
            optim_params_d = [
                {  # add normal params first
                    'params': normal_params,
                    'lr': train_opt['optim_d']['lr']
                },
                {
                    'params': linear_params,
                    'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512))
                }
            ]
        else:
            normal_params = []
            for name, param in self.net_d.named_parameters():
                normal_params.append(param)
            optim_params_d = [{  # add normal params first
                'params': normal_params,
                'lr': train_opt['optim_d']['lr']
            }]

        optim_type = train_opt['optim_d'].pop('type')
        lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
        betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
        self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
        self.optimizers.append(self.optimizer_d)

    def feed_data(self, data):
        self.real_img = data['gt'].to(self.device)

    def make_noise(self, batch, num_noise):
        if num_noise == 1:
            noises = torch.randn(batch, self.num_style_feat, device=self.device)
        else:
            noises = torch.randn(num_noise, batch, self.num_style_feat, device=self.device).unbind(0)
        return noises

    def mixing_noise(self, batch, prob):
        if random.random() < prob:
            return self.make_noise(batch, 2)
        else:
            return [self.make_noise(batch, 1)]

    def optimize_parameters(self, current_iter):
        loss_dict = OrderedDict()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True
        self.optimizer_d.zero_grad()

        batch = self.real_img.size(0)
        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.net_g(noise)
        fake_pred = self.net_d(fake_img.detach())

        real_pred = self.net_d(self.real_img)
        # wgan loss with softplus (logistic loss) for discriminator
        l_d = self.cri_gan(real_pred, True, is_disc=True) + self.cri_gan(fake_pred, False, is_disc=True)
        loss_dict['l_d'] = l_d
        # In wgan, real_score should be positive and fake_score should be
        # negative
        loss_dict['real_score'] = real_pred.detach().mean()
        loss_dict['fake_score'] = fake_pred.detach().mean()
        l_d.backward()

        if current_iter % self.net_d_reg_every == 0:
            self.real_img.requires_grad = True
            real_pred = self.net_d(self.real_img)
            l_d_r1 = r1_penalty(real_pred, self.real_img)
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
            # TODO: why do we need to add 0 * real_pred, otherwise, a runtime
            # error will arise: RuntimeError: Expected to have finished
            # reduction in the prior iteration before starting a new one.
            # This error indicates that your module has parameters that were
            # not used in producing loss.
            loss_dict['l_d_r1'] = l_d_r1.detach().mean()
            l_d_r1.backward()

        self.optimizer_d.step()

        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False
        self.optimizer_g.zero_grad()

        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.net_g(noise)
        fake_pred = self.net_d(fake_img)

        # wgan loss with softplus (non-saturating loss) for generator
        l_g = self.cri_gan(fake_pred, True, is_disc=False)
        loss_dict['l_g'] = l_g
        l_g.backward()

        if current_iter % self.net_g_reg_every == 0:
            path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink'])
            noise = self.mixing_noise(path_batch_size, self.mixing_prob)
            fake_img, latents = self.net_g(noise, return_latents=True)
            l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length)

            l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0])
            # TODO: why do we need to add 0 * fake_img[0, 0, 0, 0]
            l_g_path.backward()
            loss_dict['l_g_path'] = l_g_path.detach().mean()
            loss_dict['path_length'] = path_lengths

        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        # EMA
        self.model_ema(decay=0.5**(32 / (10 * 1000)))

    def test(self):
        with torch.no_grad():
            self.net_g_ema.eval()
            self.output, _ = self.net_g_ema([self.fixed_sample])

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        assert dataloader is None, 'Validation dataloader should be None.'
        self.test()
        result = tensor2img(self.output, min_max=(-1, 1))
        if self.opt['is_train']:
            save_img_path = osp.join(self.opt['path']['visualization'], 'train', f'train_{current_iter}.png')
        else:
            save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png')
        imwrite(result, save_img_path)
        # add sample images to tb_logger
        result = (result / 255.).astype(np.float32)
        result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
        if tb_logger is not None:
            tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC')

    def save(self, epoch, current_iter):
        self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        self.save_network(self.net_d, 'net_d', current_iter)
        self.save_training_state(epoch, current_iter)
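The `net_g_reg_ratio` / `net_d_reg_ratio` factors above implement StyleGAN2's lazy regularization: the R1 and path-length penalties run only every `reg_every` iterations, so the learning rate and Adam betas are rescaled by `reg_every / (reg_every + 1)` to keep the effective update statistics close to applying the penalty every step. A worked example of the adjustment (lr and reg_every values here are just examples, not config defaults):

lr = 0.002
reg_every = 16                       # e.g. net_d_reg_every
ratio = reg_every / (reg_every + 1)  # fraction of iterations without the penalty

adjusted_lr = lr * ratio             # ~0.001882
adjusted_betas = (0**ratio, 0.99**ratio)  # (0.0, ~0.9906)
print(adjusted_lr, adjusted_betas)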
BasicSR/basicsr/models/swinir_model.py
0 → 100644
View file @ e2696ece

import torch
from torch.nn import functional as F

from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel


@MODEL_REGISTRY.register()
class SwinIRModel(SRModel):

    def test(self):
        # pad to multiplication of window_size
        window_size = self.opt['network_g']['window_size']
        scale = self.opt.get('scale', 1)
        mod_pad_h, mod_pad_w = 0, 0
        _, _, h, w = self.lq.size()
        if h % window_size != 0:
            mod_pad_h = window_size - h % window_size
        if w % window_size != 0:
            mod_pad_w = window_size - w % window_size
        img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                self.output = self.net_g_ema(img)
        else:
            self.net_g.eval()
            with torch.no_grad():
                self.output = self.net_g(img)
            self.net_g.train()

        _, _, h, w = self.output.size()
        self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
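To see the padding arithmetic in `test` concretely: an input that is not a multiple of `window_size` is reflect-padded up to the next multiple, run through the network, and the output is cropped back by `mod_pad * scale`. A standalone sketch with a stand-in for the real network (the window_size and scale values are just examples):

import torch
import torch.nn.functional as F

window_size, scale = 8, 4
lq = torch.rand(1, 3, 30, 45)

_, _, h, w = lq.size()
mod_pad_h = (window_size - h % window_size) % window_size  # 2
mod_pad_w = (window_size - w % window_size) % window_size  # 3
img = F.pad(lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')   # -> (1, 3, 32, 48)

out = F.interpolate(img, scale_factor=scale, mode='nearest')  # stand-in for net_g
out = out[:, :, 0:out.size(2) - mod_pad_h * scale, 0:out.size(3) - mod_pad_w * scale]
assert out.shape[-2:] == (h * scale, w * scale)  # cropped back to (120, 180)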
BasicSR/basicsr/models/video_base_model.py
0 → 100644
View file @ e2696ece

import torch
from collections import Counter
from os import path as osp
from torch import distributed as dist
from tqdm import tqdm

from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel


@MODEL_REGISTRY.register()
class VideoBaseModel(SRModel):
    """Base video SR model."""

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        dataset = dataloader.dataset
        dataset_name = dataset.opt['name']
        with_metrics = self.opt['val']['metrics'] is not None
        # initialize self.metric_results
        # It is a dict: {
        #    'folder1': tensor (num_frame x len(metrics)),
        #    'folder2': tensor (num_frame x len(metrics))
        # }
        if with_metrics:
            if not hasattr(self, 'metric_results'):  # only execute in the first run
                self.metric_results = {}
                num_frame_each_folder = Counter(dataset.data_info['folder'])
                for folder, num_frame in num_frame_each_folder.items():
                    self.metric_results[folder] = torch.zeros(
                        num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
            # initialize the best metric results
            self._initialize_best_metric_results(dataset_name)
        # zero self.metric_results
        rank, world_size = get_dist_info()
        if with_metrics:
            for _, tensor in self.metric_results.items():
                tensor.zero_()

        metric_data = dict()
        # record all frames (border and center frames)
        if rank == 0:
            pbar = tqdm(total=len(dataset), unit='frame')
        for idx in range(rank, len(dataset), world_size):
            val_data = dataset[idx]
            val_data['lq'].unsqueeze_(0)
            val_data['gt'].unsqueeze_(0)
            folder = val_data['folder']
            frame_idx, max_idx = val_data['idx'].split('/')
            lq_path = val_data['lq_path']

            self.feed_data(val_data)
            self.test()
            visuals = self.get_current_visuals()
            result_img = tensor2img([visuals['result']])
            metric_data['img'] = result_img
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                metric_data['img2'] = gt_img
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    raise NotImplementedError('saving image is not supported during training.')
                else:
                    if 'vimeo' in dataset_name.lower():  # vimeo90k dataset
                        split_result = lq_path.split('/')
                        img_name = f'{split_result[-3]}_{split_result[-2]}_{split_result[-1].split(".")[0]}'
                    else:  # other datasets, e.g., REDS, Vid4
                        img_name = osp.splitext(osp.basename(lq_path))[0]

                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
                                                 f'{img_name}_{self.opt["name"]}.png')
                    imwrite(result_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
                    result = calculate_metric(metric_data, opt_)
                    self.metric_results[folder][int(frame_idx), metric_idx] += result

            # progress bar
            if rank == 0:
                for _ in range(world_size):
                    pbar.update(1)
                    pbar.set_description(f'Test {folder}: {int(frame_idx) + world_size}/{max_idx}')
        if rank == 0:
            pbar.close()

        if with_metrics:
            if self.opt['dist']:
                # collect data among GPUs
                for _, tensor in self.metric_results.items():
                    dist.reduce(tensor, 0)
                dist.barrier()
            else:
                pass  # assume use one gpu in non-dist testing

            if rank == 0:
                self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        logger = get_root_logger()
        logger.warning('nondist_validation is not implemented. Run dist_validation.')
        self.dist_validation(dataloader, current_iter, tb_logger, save_img)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        # ----------------- calculate the average values for each folder, and for each metric ----------------- #
        # average all frames for each sub-folder
        # metric_results_avg is a dict:{
        #    'folder1': tensor (len(metrics)),
        #    'folder2': tensor (len(metrics))
        # }
        metric_results_avg = {
            folder: torch.mean(tensor, dim=0).cpu()
            for (folder, tensor) in self.metric_results.items()
        }
        # total_avg_results is a dict: {
        #    'metric1': float,
        #    'metric2': float
        # }
        total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
        for folder, tensor in metric_results_avg.items():
            for idx, metric in enumerate(total_avg_results.keys()):
                total_avg_results[metric] += metric_results_avg[folder][idx].item()
        # average among folders
        for metric in total_avg_results.keys():
            total_avg_results[metric] /= len(metric_results_avg)
            # update the best metric result
            self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter)

        # ------------------------------------------ log the metric ------------------------------------------ #
        log_str = f'Validation {dataset_name}\n'
        for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
            log_str += f'\t # {metric}: {value:.4f}'
            for folder, tensor in metric_results_avg.items():
                log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
            if hasattr(self, 'best_metric_results'):
                log_str += (f'\n\t    Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
            log_str += '\n'

        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
                for folder, tensor in metric_results_avg.items():
                    tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter)
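`_log_validation_metric_values` first averages each folder's `(num_frame, num_metrics)` tensor over frames, then averages those per-folder means across folders, so every folder counts equally regardless of its frame count. A small numeric sketch of that two-stage average (folder names and PSNR values are made up for illustration):

import torch

metric_results = {
    'clip_a': torch.tensor([[30.0], [32.0]]),          # 2 frames, 1 metric (e.g. PSNR)
    'clip_b': torch.tensor([[20.0], [22.0], [24.0]]),  # 3 frames
}
per_folder = {k: v.mean(dim=0) for k, v in metric_results.items()}
overall = sum(per_folder.values()) / len(per_folder)
print(per_folder, overall)  # clip_a: 31.0, clip_b: 22.0 -> 26.5 (not the per-frame mean 25.6)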
BasicSR/basicsr/models/video_gan_model.py
0 → 100644
View file @ e2696ece

from basicsr.utils.registry import MODEL_REGISTRY
from .srgan_model import SRGANModel
from .video_base_model import VideoBaseModel


@MODEL_REGISTRY.register()
class VideoGANModel(SRGANModel, VideoBaseModel):
    """Video GAN model.

    Use multiple inheritance.
    It will first use the functions of :class:`SRGANModel`:
        - :func:`init_training_settings`
        - :func:`setup_optimizers`
        - :func:`optimize_parameters`
        - :func:`save`
    Then find functions in :class:`VideoBaseModel`.
    """
BasicSR/basicsr/models/video_recurrent_gan_model.py
0 → 100644
View file @ e2696ece

import torch
from collections import OrderedDict

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
from .video_recurrent_model import VideoRecurrentModel


@MODEL_REGISTRY.register()
class VideoRecurrentGANModel(VideoRecurrentModel):

    def init_training_settings(self):
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # build network net_g with Exponential Moving Average (EMA)
            # net_g_ema only used for testing on one GPU and saving.
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define network net_d
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_d', 'params')
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)

        self.net_g.train()
        self.net_d.train()

        # define losses
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        if train_opt.get('gan_opt'):
            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        train_opt = self.opt['train']
        if train_opt['fix_flow']:
            normal_params = []
            flow_params = []
            for name, param in self.net_g.named_parameters():
                if 'spynet' in name:  # The fix_flow now only works for spynet.
                    flow_params.append(param)
                else:
                    normal_params.append(param)

            optim_params = [
                {  # add flow params first
                    'params': flow_params,
                    'lr': train_opt['lr_flow']
                },
                {
                    'params': normal_params,
                    'lr': train_opt['optim_g']['lr']
                },
            ]
        else:
            optim_params = self.net_g.parameters()

        # optimizer g
        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)
        # optimizer d
        optim_type = train_opt['optim_d'].pop('type')
        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
        self.optimizers.append(self.optimizer_d)

    def optimize_parameters(self, current_iter):
        logger = get_root_logger()
        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False

        if self.fix_flow_iter:
            if current_iter == 1:
                logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
                for name, param in self.net_g.named_parameters():
                    if 'spynet' in name or 'edvr' in name:
                        param.requires_grad_(False)
            elif current_iter == self.fix_flow_iter:
                logger.warning('Train all the parameters.')
                self.net_g.requires_grad_(True)

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)
        _, _, c, h, w = self.output.size()

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), self.gt.view(-1, c, h, w))
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output.view(-1, c, h, w))
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        # reshape to (b*n, c, h, w)
        real_d_pred = self.net_d(self.gt.view(-1, c, h, w))
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        # reshape to (b*n, c, h, w)
        fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach())
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def save(self, epoch, current_iter):
        if self.ema_decay > 0:
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_network(self.net_d, 'net_d', current_iter)
        self.save_training_state(epoch, current_iter)
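The discriminator above is an image discriminator, so the recurrent output of shape `(b, n, c, h, w)` is flattened to `(b*n, c, h, w)` with `view` before each `net_d` call, treating every frame as its own sample. A quick shape check of that reshape:

import torch

b, n, c, h, w = 2, 5, 3, 16, 16
output = torch.rand(b, n, c, h, w)       # b clips of n frames each
frames = output.view(-1, c, h, w)        # every frame becomes a separate sample
assert frames.shape == (b * n, c, h, w)  # (10, 3, 16, 16)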
BasicSR/basicsr/models/video_recurrent_model.py
0 → 100644
View file @ e2696ece

import torch
from collections import Counter
from os import path as osp
from torch import distributed as dist
from tqdm import tqdm

from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
from .video_base_model import VideoBaseModel


@MODEL_REGISTRY.register()
class VideoRecurrentModel(VideoBaseModel):

    def __init__(self, opt):
        super(VideoRecurrentModel, self).__init__(opt)
        if self.is_train:
            self.fix_flow_iter = opt['train'].get('fix_flow')

    def setup_optimizers(self):
        train_opt = self.opt['train']
        flow_lr_mul = train_opt.get('flow_lr_mul', 1)
        logger = get_root_logger()
        logger.info(f'Multiple the learning rate for flow network with {flow_lr_mul}.')
        if flow_lr_mul == 1:
            optim_params = self.net_g.parameters()
        else:  # separate flow params and normal params for different lr
            normal_params = []
            flow_params = []
            for name, param in self.net_g.named_parameters():
                if 'spynet' in name:
                    flow_params.append(param)
                else:
                    normal_params.append(param)
            optim_params = [
                {  # add normal params first
                    'params': normal_params,
                    'lr': train_opt['optim_g']['lr']
                },
                {
                    'params': flow_params,
                    'lr': train_opt['optim_g']['lr'] * flow_lr_mul
                },
            ]

        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)

    def optimize_parameters(self, current_iter):
        if self.fix_flow_iter:
            logger = get_root_logger()
            if current_iter == 1:
                logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
                for name, param in self.net_g.named_parameters():
                    if 'spynet' in name or 'edvr' in name:
                        param.requires_grad_(False)
            elif current_iter == self.fix_flow_iter:
                logger.warning('Train all the parameters.')
                self.net_g.requires_grad_(True)

        super(VideoRecurrentModel, self).optimize_parameters(current_iter)

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        dataset = dataloader.dataset
        dataset_name = dataset.opt['name']
        with_metrics = self.opt['val']['metrics'] is not None
        # initialize self.metric_results
        # It is a dict: {
        #    'folder1': tensor (num_frame x len(metrics)),
        #    'folder2': tensor (num_frame x len(metrics))
        # }
        if with_metrics:
            if not hasattr(self, 'metric_results'):  # only execute in the first run
                self.metric_results = {}
                num_frame_each_folder = Counter(dataset.data_info['folder'])
                for folder, num_frame in num_frame_each_folder.items():
                    self.metric_results[folder] = torch.zeros(
                        num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
            # initialize the best metric results
            self._initialize_best_metric_results(dataset_name)
        # zero self.metric_results
        rank, world_size = get_dist_info()
        if with_metrics:
            for _, tensor in self.metric_results.items():
                tensor.zero_()

        metric_data = dict()
        num_folders = len(dataset)
        num_pad = (world_size - (num_folders % world_size)) % world_size
        if rank == 0:
            pbar = tqdm(total=len(dataset), unit='folder')
        # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded.
        # (To avoid wait-dead)
        for i in range(rank, num_folders + num_pad, world_size):
            idx = min(i, num_folders - 1)
            val_data = dataset[idx]
            folder = val_data['folder']

            # compute outputs
            val_data['lq'].unsqueeze_(0)
            val_data['gt'].unsqueeze_(0)
            self.feed_data(val_data)
            val_data['lq'].squeeze_(0)
            val_data['gt'].squeeze_(0)

            self.test()
            visuals = self.get_current_visuals()

            # tentative for out of GPU memory
            del self.lq
            del self.output
            if 'gt' in visuals:
                del self.gt
            torch.cuda.empty_cache()

            if self.center_frame_only:
                visuals['result'] = visuals['result'].unsqueeze(1)
                if 'gt' in visuals:
                    visuals['gt'] = visuals['gt'].unsqueeze(1)

            # evaluate
            if i < num_folders:
                for idx in range(visuals['result'].size(1)):
                    result = visuals['result'][0, idx, :, :, :]
                    result_img = tensor2img([result])  # uint8, bgr
                    metric_data['img'] = result_img
                    if 'gt' in visuals:
                        gt = visuals['gt'][0, idx, :, :, :]
                        gt_img = tensor2img([gt])  # uint8, bgr
                        metric_data['img2'] = gt_img

                    if save_img:
                        if self.opt['is_train']:
                            raise NotImplementedError('saving image is not supported during training.')
                        else:
                            if self.center_frame_only:  # vimeo-90k
                                clip_ = val_data['lq_path'].split('/')[-3]
                                seq_ = val_data['lq_path'].split('/')[-2]
                                name_ = f'{clip_}_{seq_}'
                                img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
                                                    f"{name_}_{self.opt['name']}.png")
                            else:  # others
                                img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
                                                    f"{idx:08d}_{self.opt['name']}.png")
                                # image name only for REDS dataset
                            imwrite(result_img, img_path)

                    # calculate metrics
                    if with_metrics:
                        for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
                            result = calculate_metric(metric_data, opt_)
                            self.metric_results[folder][idx, metric_idx] += result

            # progress bar
            if rank == 0:
                for _ in range(world_size):
                    pbar.update(1)
                    pbar.set_description(f'Folder: {folder}')

        if rank == 0:
            pbar.close()

        if with_metrics:
            if self.opt['dist']:
                # collect data among GPUs
                for _, tensor in self.metric_results.items():
                    dist.reduce(tensor, 0)
                dist.barrier()

            if rank == 0:
                self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def test(self):
        n = self.lq.size(1)
        self.net_g.eval()

        flip_seq = self.opt['val'].get('flip_seq', False)
        self.center_frame_only = self.opt['val'].get('center_frame_only', False)

        if flip_seq:
            self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1)

        with torch.no_grad():
            self.output = self.net_g(self.lq)

        if flip_seq:
            output_1 = self.output[:, :n, :, :, :]
            output_2 = self.output[:, n:, :, :, :].flip(1)
            self.output = 0.5 * (output_1 + output_2)

        if self.center_frame_only:
            self.output = self.output[:, n // 2, :, :, :]

        self.net_g.train()
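`flip_seq` in `test` above is a temporal self-ensemble: the clip is concatenated with its time-reversed copy along dim 1, the second half of the output is restored to forward order, and the two halves are averaged. A sketch of the splitting/averaging step with an identity stand-in for the network:

import torch

n = 4
lq = torch.rand(1, n, 3, 8, 8)
inp = torch.cat([lq, lq.flip(1)], dim=1)   # (1, 2n, 3, 8, 8)

output = inp                               # stand-in for net_g
output_1 = output[:, :n, :, :, :]
output_2 = output[:, n:, :, :, :].flip(1)  # re-reverse the second half
ensembled = 0.5 * (output_1 + output_2)
assert torch.allclose(ensembled, lq)       # identity net -> recovers the input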
BasicSR/basicsr/ops/__init__.py
0 → 100644
View file @ e2696ece

BasicSR/basicsr/ops/dcn/__init__.py
0 → 100644
View file @ e2696ece

from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
                          modulated_deform_conv)

__all__ = [
    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
    'modulated_deform_conv'
]