ModelZoo / ControlNet_pytorch / Commits

Commit e2696ece, authored Nov 22, 2023 by mashun1

controlnet

Pipeline #643 canceled with stages
Showing 20 changed files with 2693 additions and 0 deletions (+2693, -0)
BasicSR/basicsr/data/paired_image_dataset.py        +106  -0
BasicSR/basicsr/data/prefetch_dataloader.py         +122  -0
BasicSR/basicsr/data/realesrgan_dataset.py          +193  -0
BasicSR/basicsr/data/realesrgan_paired_dataset.py   +106  -0
BasicSR/basicsr/data/reds_dataset.py                +352  -0
BasicSR/basicsr/data/single_image_dataset.py        +68   -0
BasicSR/basicsr/data/transforms.py                  +179  -0
BasicSR/basicsr/data/video_test_dataset.py          +283  -0
BasicSR/basicsr/data/vimeo90k_dataset.py            +199  -0
BasicSR/basicsr/losses/__init__.py                  +31   -0
BasicSR/basicsr/losses/basic_loss.py                +253  -0
BasicSR/basicsr/losses/gan_loss.py                  +207  -0
BasicSR/basicsr/losses/loss_util.py                 +145  -0
BasicSR/basicsr/metrics/README.md                   +48   -0
BasicSR/basicsr/metrics/README_CN.md                +48   -0
BasicSR/basicsr/metrics/__init__.py                 +20   -0
BasicSR/basicsr/metrics/fid.py                      +89   -0
BasicSR/basicsr/metrics/metric_util.py              +45   -0
BasicSR/basicsr/metrics/niqe.py                     +199  -0
BasicSR/basicsr/metrics/niqe_pris_params.npz        +0    -0
Too many changes to show. To preserve performance only 263 of 263+ files are displayed.
BasicSR/basicsr/data/paired_image_dataset.py
0 → 100644
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info_file**: Use meta information file to generate paths. \
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                          self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
            img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]

        # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
        # TODO: It is better to update the datasets, rather than force to crop
        if self.opt['phase'] != 'train':
            img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
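A minimal usage sketch for PairedImageDataset follows; the folder paths and option values are hypothetical, chosen only to exercise the keys the class reads ('disk' is one of the FileClient backend types):

    # Hypothetical config; point the paths at a real paired dataset.
    opt = {
        'dataroot_gt': 'datasets/DIV2K/GT',     # assumed GT folder
        'dataroot_lq': 'datasets/DIV2K/LQ_x4',  # assumed LQ folder
        'io_backend': {'type': 'disk'},
        'scale': 4,
        'phase': 'train',
        'gt_size': 128,
        'use_hflip': True,
        'use_rot': True,
    }
    dataset = PairedImageDataset(opt)
    sample = dataset[0]
    # sample['lq']: (c, 32, 32) tensor; sample['gt']: (c, 128, 128) tensor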
BasicSR/basicsr/data/prefetch_dataloader.py
0 → 100644
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
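The prefetchers above wrap an existing dataloader. A minimal sketch of the intended loop for CUDAPrefetcher, assuming `loader` is a regular DataLoader yielding dicts of tensors (hypothetical here); next() hands back the batch preloaded on the side stream while the following batch is copied asynchronously:

    # Hypothetical usage; `loader` is assumed to exist.
    prefetcher = CUDAPrefetcher(loader, opt={'num_gpu': 1})
    batch = prefetcher.next()
    while batch is not None:
        # batch['lq'], batch['gt'], etc. are already on the CUDA device
        # ... forward/backward pass here ...
        batch = prefetcher.next()
    prefetcher.reset()  # rewind the underlying loader for the next epoch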
BasicSR/basicsr/data/realesrgan_dataset.py
0 → 100644
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from torch.utils import data as data

from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUS for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
                self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except (IOError, OSError) as e:
                logger = get_root_logger()
                logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read
                index = random.randint(0, self.__len__() - 1)  # randint is inclusive on both ends
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
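As the docstring notes, the kernels returned here are applied to image tensors later in the pipeline (typically on GPU) rather than inside the dataset. A self-contained sketch of that idea using plain torch.nn.functional.conv2d; the actual training code may use its own filtering helper, so this is illustrative only:

    import torch
    import torch.nn.functional as F

    # Blur a (b, c, h, w) batch with one (k, k) kernel, k odd, keeping the size.
    def apply_kernel(img, kernel):
        b, c, h, w = img.size()
        k = kernel.size(-1)
        weight = kernel.view(1, 1, k, k).repeat(c, 1, 1, 1)  # one copy per channel
        img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
        return F.conv2d(img, weight, groups=c)  # output stays (b, c, h, w)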
BasicSR/basicsr/data/realesrgan_paired_dataset.py
0 → 100644
import os
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:

    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
    2. **meta_info_file**: Use meta information file to generate paths. \
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. **folder**: Scan folders to generate paths. The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # mean and std for normalizing the input images
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip() for line in fin]
            self.paths = []
            for path in paths:
                gt_path, lq_path = path.split(', ')
                gt_path = os.path.join(self.gt_folder, gt_path)
                lq_path = os.path.join(self.lq_folder, lq_path)
                self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
        else:
            # disk backend
            # it will scan the whole folder to get meta info
            # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
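In the meta_info branch above, each line is split on ', ' into a GT path and an LQ path, taken relative to dataroot_gt and dataroot_lq respectively. A hypothetical meta_info file (filenames invented for illustration) would therefore look like:

    0001.png, 0001_x4.png
    0002.png, 0002_x4.png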
BasicSR/basicsr/data/reds_dataset.py
0 → 100644
import numpy as np
import random
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.flow_util import dequantize_flow
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class REDSDataset(data.Dataset):
    """REDS dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.

    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
        assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
        self.num_frame = opt['num_frame']
        self.num_half_frames = opt['num_frame'] // 2

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f"Supported ones are ['official', 'REDS4'].")
        self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list'])
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000
        center_frame_idx = int(frame_name)

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = center_frame_idx - self.num_half_frames * interval
        end_frame_idx = center_frame_idx + self.num_half_frames * interval
        # each clip has 100 frames starting from 0 to 99
        while (start_frame_idx < 0) or (end_frame_idx > 99):
            center_frame_idx = random.randint(0, 99)
            start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
            end_frame_idx = center_frame_idx + self.num_half_frames * interval
        frame_name = f'{center_frame_idx:08d}'
        neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the GT frame (as the center frame)
        if self.is_lmdb:
            img_gt_path = f'{clip_name}/{frame_name}'
        else:
            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # get flows
        if self.flow_root is not None:
            img_flows = []
            # read previous flows
            for i in range(self.num_half_frames, 0, -1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_p{i}'
                else:
                    flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)
            # read next flows
            for i in range(1, self.num_half_frames + 1):
                if self.is_lmdb:
                    flow_path = f'{clip_name}/{frame_name}_n{i}'
                else:
                    flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
                img_bytes = self.file_client.get(flow_path, 'flow')
                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
                dx, dy = np.split(cat_flow, 2, axis=0)
                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
                img_flows.append(flow)

            # for random crop, here, img_flows and img_lqs have the same
            # spatial size
            img_lqs.extend(img_flows)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
        if self.flow_root is not None:
            img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        if self.flow_root is not None:
            img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
        else:
            img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        if self.flow_root is not None:
            img_flows = img2tensor(img_flows)
            # add the zero center flow
            img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
            img_flows = torch.stack(img_flows, dim=0)

        # img_lqs: (t, c, h, w)
        # img_flows: (t, 2, h, w)
        # img_gt: (c, h, w)
        # key: str
        if self.flow_root is not None:
            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
        else:
            return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)


@DATASET_REGISTRY.register()
class REDSRecurrentDataset(data.Dataset):
    """REDS dataset for training recurrent networks.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_REDS_GT.txt

    Each line contains:
    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
    a white space.

    Examples:
    000 100 (720,1280,3)
    001 100 (720,1280,3)
    ...

    Key examples: "000/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            dataroot_flow (str, optional): Data root path for flow.
            meta_info_file (str): Path for meta information file.
            val_partition (str): Validation partition types. 'REDS4' or 'official'.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(REDSRecurrentDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
        self.num_frame = opt['num_frame']

        self.keys = []
        with open(opt['meta_info_file'], 'r') as fin:
            for line in fin:
                folder, frame_num, _ = line.split(' ')
                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])

        # remove the video clips used in validation
        if opt['val_partition'] == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif opt['val_partition'] == 'official':
            val_partition = [f'{v:03d}' for v in range(240, 270)]
        else:
            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
                             f"Supported ones are ['official', 'REDS4'].")
        if opt['test_mode']:
            self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
        else:
            self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            if hasattr(self, 'flow_root') and self.flow_root is not None:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
            else:
                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
                self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # temporal augmentation configs
        self.interval_list = opt.get('interval_list', [1])
        self.random_reverse = opt.get('random_reverse', False)
        interval_str = ','.join(str(x) for x in self.interval_list)
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip_name, frame_name = key.split('/')  # key example: 000/00000000

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = int(frame_name)
        if start_frame_idx > 100 - self.num_frame * interval:
            start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
        end_frame_idx = start_frame_idx + self.num_frame * interval

        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip_name}/{neighbor:08d}'
                img_gt_path = f'{clip_name}/{neighbor:08d}'
            else:
                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
                img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'

            # get LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

            # get GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)
            img_gts.append(img_gt)

        # randomly crop
        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.extend(img_gts)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
        img_results = img2tensor(img_results)
        img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
        img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)

        # img_lqs: (t, c, h, w)
        # img_gts: (t, c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
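To make the temporal sampling in REDSDataset concrete: with num_frame=5 (so num_half_frames=2) and a sampled interval of 2, center frame 00000010 gives neighbor_list [6, 8, 10, 12, 14]; when the window would cross frame 0 or 99, a new center is redrawn until it fits. A standalone sketch of just that index logic:

    import random

    # Mirror of the window computation in REDSDataset.__getitem__ (frames 0..99).
    def sample_window(center, num_half, interval, last_frame=99):
        start = center - num_half * interval
        end = center + num_half * interval
        while start < 0 or end > last_frame:
            center = random.randint(0, last_frame)
            start = center - num_half * interval
            end = center + num_half * interval
        return list(range(start, end + 1, interval))

    print(sample_window(10, 2, 2))  # [6, 8, 10, 12, 14]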
BasicSR/basicsr/data/single_image_dataset.py
0 → 100644
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    There are two modes:
    1. 'meta_info_file': Use meta information file to generate paths.
    2. 'folder': Scan folders to generate paths.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
BasicSR/basicsr/data/transforms.py
0 → 100644
import cv2
import random
import torch


def mod_crop(img, scale):
    """Mod crop images, used during testing.

    Args:
        img (ndarray): Input image.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.
    """
    img = img.copy()
    if img.ndim in (2, 3):
        h, w = img.shape[0], img.shape[1]
        h_remainder, w_remainder = h % scale, w % scale
        img = img[:h - h_remainder, :w - w_remainder, ...]
    else:
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    return img


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Paired random crop. Support Numpy array and Tensor inputs.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth. Default: None.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # determine input type: Numpy array or Tensor
    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'

    if input_type == 'Tensor':
        h_lq, w_lq = img_lqs[0].size()[-2:]
        h_gt, w_gt = img_gts[0].size()[-2:]
    else:
        h_lq, w_lq = img_lqs[0].shape[0:2]
        h_gt, w_gt = img_gts[0].shape[0:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    if input_type == 'Tensor':
        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
    else:
        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    if input_type == 'Tensor':
        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
    else:
        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs


def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.
    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs


def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.
    """
    (h, w) = img.shape[:2]

    if center is None:
        center = (w // 2, h // 2)

    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    rotated_img = cv2.warpAffine(img, matrix, (w, h))
    return rotated_img
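A short sketch of how paired_random_crop and augment compose in a training pipeline; the arrays are random stand-ins for decoded BGR images of a 4x pair:

    import numpy as np

    img_gt = np.random.rand(256, 256, 3).astype(np.float32)  # stand-in GT
    img_lq = np.random.rand(64, 64, 3).astype(np.float32)    # stand-in LQ (4x smaller)

    # Crop a 128px GT patch and the matching 32px LQ patch at the same location,
    # then apply the same random flip/rotation to both.
    img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_patch_size=128, scale=4)
    img_gt, img_lq = augment([img_gt, img_lq], hflip=True, rotation=True)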
BasicSR/basicsr/data/video_test_dataset.py
0 → 100644
import glob
import torch
from os import path as osp
from torch.utils import data as data

from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class VideoTestDataset(data.Dataset):
    """Video test dataset.

    Supported datasets: Vid4, REDS4, REDSofficial.
    More generally, it supports testing dataset with following structures:

    ::

        dataroot
        ├── subfolder1
            ├── frame000
            ├── frame001
            ├── ...
        ├── subfolder2
            ├── frame000
            ├── frame001
            ├── ...
        ├── ...

    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
                in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        self.imgs_lq, self.imgs_gt = {}, {}
        if 'meta_info_file' in opt:
            with open(opt['meta_info_file'], 'r') as fin:
                subfolders = [line.split(' ')[0] for line in fin]
                subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
                subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
        else:
            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))

        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
            for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
                # get frame list for lq and gt
                subfolder_name = osp.basename(subfolder_lq)
                img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
                img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))

                max_idx = len(img_paths_lq)
                assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
                                                      f' and gt folders ({len(img_paths_gt)})')

                self.data_info['lq_path'].extend(img_paths_lq)
                self.data_info['gt_path'].extend(img_paths_gt)
                self.data_info['folder'].extend([subfolder_name] * max_idx)
                for i in range(max_idx):
                    self.data_info['idx'].append(f'{i}/{max_idx}')
                border_l = [0] * max_idx
                for i in range(self.opt['num_frame'] // 2):
                    border_l[i] = 1
                    border_l[max_idx - i - 1] = 1
                self.data_info['border'].extend(border_l)

                # cache data or save the frame list
                if self.cache_data:
                    logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
                else:
                    self.imgs_lq[subfolder_name] = img_paths_lq
                    self.imgs_gt[subfolder_name] = img_paths_gt
        else:
            raise ValueError(f'Non-supported video test dataset: {type(opt["name"])}')

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
            imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])


@DATASET_REGISTRY.register()
class VideoTestVimeo90KDataset(data.Dataset):
    """Video test dataset for Vimeo90k-Test dataset.

    It only keeps the center frame for testing.
    For testing datasets, there is no need to prepare LMDB files.

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            io_backend (dict): IO backend type and other kwarg.
            cache_data (bool): Whether to cache testing datasets.
            name (str): Dataset name.
            meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
                in the dataroot will be used.
            num_frame (int): Window size for input frames.
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoTestVimeo90KDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        if self.cache_data:
            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'

        logger = get_root_logger()
        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
        with open(opt['meta_info_file'], 'r') as fin:
            subfolders = [line.split(' ')[0] for line in fin]
        for idx, subfolder in enumerate(subfolders):
            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
            self.data_info['gt_path'].append(gt_path)
            lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
            self.data_info['lq_path'].append(lq_paths)
            self.data_info['folder'].append('vimeo90k')
            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
            self.data_info['border'].append(0)

    def __getitem__(self, index):
        lq_path = self.data_info['lq_path'][index]
        gt_path = self.data_info['gt_path'][index]
        imgs_lq = read_img_seq(lq_path)
        img_gt = read_img_seq([gt_path])
        img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': self.data_info['folder'][index],  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/843
            'border': self.data_info['border'][index],  # 0 for non-border
            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
        }

    def __len__(self):
        return len(self.data_info['gt_path'])


@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
    """ Video test dataset for DUF dataset.

    Args:
        opt (dict): Config for train dataset. Most of keys are the same as VideoTestDataset.
            It has the following extra keys:
            use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames.
            scale (bool): Scale, which will be added automatically.
    """

    def __getitem__(self, index):
        folder = self.data_info['folder'][index]
        idx, max_idx = self.data_info['idx'][index].split('/')
        idx, max_idx = int(idx), int(max_idx)
        border = self.data_info['border'][index]
        lq_path = self.data_info['lq_path'][index]

        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])

        if self.cache_data:
            if self.opt['use_duf_downsampling']:
                # read imgs_gt to generate low-resolution frames
                imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
            img_gt = self.imgs_gt[folder][idx]
        else:
            if self.opt['use_duf_downsampling']:
                img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
                # read imgs_gt to generate low-resolution frames
                imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
            else:
                img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
                imgs_lq = read_img_seq(img_paths_lq)
            img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
            img_gt.squeeze_(0)

        return {
            'lq': imgs_lq,  # (t, c, h, w)
            'gt': img_gt,  # (c, h, w)
            'folder': folder,  # folder name
            'idx': self.data_info['idx'][index],  # e.g., 0/99
            'border': border,  # 1 for border, 0 for non-border
            'lq_path': lq_path  # center frame
        }


@DATASET_REGISTRY.register()
class VideoRecurrentTestDataset(VideoTestDataset):
    """Video test dataset for recurrent architectures, which takes LR video
    frames as input and output corresponding HR video frames.

    Args:
        opt (dict): Same as VideoTestDataset. Unused opt:
            padding (str): Padding mode.
    """

    def __init__(self, opt):
        super(VideoRecurrentTestDataset, self).__init__(opt)
        # Find unique folder strings
        self.folders = sorted(list(set(self.data_info['folder'])))

    def __getitem__(self, index):
        folder = self.folders[index]

        if self.cache_data:
            imgs_lq = self.imgs_lq[folder]
            imgs_gt = self.imgs_gt[folder]
        else:
            raise NotImplementedError('Without cache_data is not implemented.')

        return {
            'lq': imgs_lq,
            'gt': imgs_gt,
            'folder': folder,
        }

    def __len__(self):
        return len(self.folders)
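A minimal sketch of instantiating VideoRecurrentTestDataset, which per the class above only works with cache_data=True; the dataroot paths are hypothetical:

    # Hypothetical config; only keys read by VideoTestDataset.__init__ are set.
    opt = {
        'name': 'REDS4',                                # must lower-case to vid4/reds4/redsofficial
        'dataroot_gt': 'datasets/REDS4/GT',             # assumed path
        'dataroot_lq': 'datasets/REDS4/sharp_bicubic',  # assumed path
        'io_backend': {'type': 'disk'},
        'cache_data': True,   # the non-cached branch raises NotImplementedError
        'num_frame': 5,       # used here only to mark border frames
    }
    dataset = VideoRecurrentTestDataset(opt)
    clip = dataset[0]  # {'lq': (t, c, h, w), 'gt': (t, c, h, w), 'folder': str}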
BasicSR/basicsr/data/vimeo90k_dataset.py
0 → 100644
import random
import torch
from pathlib import Path
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains the following items, separated by a white space.

    1. clip name;
    2. frame number;
    3. image shape

    Examples:

    ::

        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    - Key examples: "00001/0001"
    - GT (gt): Ground-Truth;
    - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:

    ::

        num_frame | frame list
                1 | 4
                3 | 3,4,5
                5 | 2,3,4,5,6
                7 | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            self.keys = [line.split(' ')[0] for line in fin]

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images
        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)

        # augmentation - flip, rotate
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)


@DATASET_REGISTRY.register()
class Vimeo90KRecurrentDataset(Vimeo90KDataset):

    def __init__(self, opt):
        super(Vimeo90KRecurrentDataset, self).__init__(opt)

        self.flip_sequence = opt['flip_sequence']
        self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the neighboring LQ and GT frames
        img_lqs = []
        img_gts = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
                img_gt_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
                img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
            # LQ
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            # GT
            img_bytes = self.file_client.get(img_gt_path, 'gt')
            img_gt = imfrombytes(img_bytes, float32=True)

            img_lqs.append(img_lq)
            img_gts.append(img_gt)
(
img_gt
)
# randomly crop
img_gts
,
img_lqs
=
paired_random_crop
(
img_gts
,
img_lqs
,
gt_size
,
scale
,
img_gt_path
)
# augmentation - flip, rotate
img_lqs
.
extend
(
img_gts
)
img_results
=
augment
(
img_lqs
,
self
.
opt
[
'use_hflip'
],
self
.
opt
[
'use_rot'
])
img_results
=
img2tensor
(
img_results
)
img_lqs
=
torch
.
stack
(
img_results
[:
7
],
dim
=
0
)
img_gts
=
torch
.
stack
(
img_results
[
7
:],
dim
=
0
)
if
self
.
flip_sequence
:
# flip the sequence: 7 frames to 14 frames
img_lqs
=
torch
.
cat
([
img_lqs
,
img_lqs
.
flip
(
0
)],
dim
=
0
)
img_gts
=
torch
.
cat
([
img_gts
,
img_gts
.
flip
(
0
)],
dim
=
0
)
# img_lqs: (t, c, h, w)
# img_gt: (c, h, w)
# key: str
return
{
'lq'
:
img_lqs
,
'gt'
:
img_gts
,
'key'
:
key
}
def
__len__
(
self
):
return
len
(
self
.
keys
)
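
A quick standalone check of the centered-window arithmetic used for `neighbor_list` above; it reproduces the frame-list table from the Vimeo90KDataset docstring.

# Verifies neighbor_list = [i + (9 - num_frame) // 2 for i in range(num_frame)].
for num_frame in (1, 3, 5, 7):
    print(num_frame, [i + (9 - num_frame) // 2 for i in range(num_frame)])
# 1 [4]
# 3 [3, 4, 5]
# 5 [2, 3, 4, 5, 6]
# 7 [1, 2, 3, 4, 5, 6, 7]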
BasicSR/basicsr/losses/__init__.py  0 → 100644

import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import LOSS_REGISTRY
from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty

__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']

# automatically scan and import loss modules for registry
# scan all the files under the 'losses' folder and collect files ending with '_loss.py'
loss_folder = osp.dirname(osp.abspath(__file__))
loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
# import all the loss modules
_model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames]


def build_loss(opt):
    """Build loss from options.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Loss type.
    """
    opt = deepcopy(opt)
    loss_type = opt.pop('type')
    loss = LOSS_REGISTRY.get(loss_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss
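
For illustration: `build_loss` pops `type` and forwards the remaining keys as constructor kwargs, so a config dict maps one-to-one onto a registered loss class. A minimal sketch using the `L1Loss` registered in `basic_loss.py` below; the option values are hypothetical.

# Any key except 'type' becomes a constructor kwarg.
loss_opt = {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}
criterion = build_loss(loss_opt)  # equivalent to L1Loss(loss_weight=1.0, reduction='mean')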
BasicSR/basicsr/losses/basic_loss.py  0 → 100644

import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.archs.vgg_arch import VGGFeatureExtractor
from basicsr.utils.registry import LOSS_REGISTRY
from .loss_util import weighted_loss

_reduction_modes = ['none', 'mean', 'sum']


@weighted_loss
def l1_loss(pred, target):
    return F.l1_loss(pred, target, reduction='none')


@weighted_loss
def mse_loss(pred, target):
    return F.mse_loss(pred, target, reduction='none')


@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    return torch.sqrt((pred - target)**2 + eps)


@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)


@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)


@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
    variant of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Loss weight for Charbonnier loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero. Default: 1e-12.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super(CharbonnierLoss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
        """
        return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)


@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted TV loss.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        if reduction not in ['mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)

    def forward(self, pred, weight=None):
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        loss = x_diff + y_diff

        return loss


@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will be multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will be multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'fro':
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]),
                        p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(
                        self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
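
A small numeric sketch of the element-wise losses above (standalone, values chosen by hand): with a vanishing `eps`, the Charbonnier loss approaches the plain L1 loss.

import torch

pred = torch.tensor([[[[0., 2., 3.]]]])    # (N, C, H, W)
target = torch.tensor([[[[1., 1., 1.]]]])

l1 = L1Loss(loss_weight=1.0, reduction='mean')
charb = CharbonnierLoss(loss_weight=1.0, reduction='mean', eps=1e-12)
print(l1(pred, target))     # tensor(1.3333) = mean(|[-1, 1, 2]|)
print(charb(pred, target))  # ~1.3333, since sqrt(d**2 + eps) ≈ |d| for tiny eps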
BasicSR/basicsr/losses/gan_loss.py  0 → 100644

import math
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import LOSS_REGISTRY


@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; it is always 1.0
            for discriminators.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan_softplus':
            self.loss = self._wgan_softplus_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with softplus. Softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for discriminator;
            Non-saturating loss for generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return F.softplus(-input).mean() if target else F.softplus(input).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return Tensor.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is computed for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight


@LOSS_REGISTRY.register()
class MultiScaleGANLoss(GANLoss):
    """MultiScaleGANLoss accepts a list of predictions."""

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)

    def forward(self, input, target_is_real, is_disc=False):
        """The input is a list of tensors, or a list of (a list of tensors)."""
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    # Only compute GAN loss for the last layer
                    # in case of multiscale feature matching
                    pred_i = pred_i[-1]
                # Safe operation: a 0-dim tensor calling self.mean() does nothing
                loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
                loss += loss_tensor
            return loss / len(input)
        else:
            return super().forward(input, target_is_real, is_disc)


def r1_penalty(real_pred, real_img):
    """R1 regularization for discriminator. The core idea is to penalize the
    gradient on real data alone: when the generator distribution produces the
    true data distribution and the discriminator is equal to 0 on the data
    manifold, the gradient penalty ensures that the discriminator cannot
    create a non-zero gradient orthogonal to the data manifold without
    suffering a loss in the GAN game.

    Reference: Eq. 9 in "Which training methods for GANs do actually converge?"
    """
    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
    return grad_penalty


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()


def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate the gradient penalty for wgan-gp.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor. Default: None.

    Returns:
        Tensor: A tensor for the gradient penalty.
    """
    batch_size = real_data.size(0)
    alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))

    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if weight is not None:
        gradients = gradients * weight

    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        gradients_penalty /= torch.mean(weight)

    return gradients_penalty
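
A minimal sketch of how `GANLoss` is typically driven from a training loop. The discriminator outputs here are stand-in random tensors, not real network predictions, and the loss_weight value is arbitrary.

import torch

cri_gan = GANLoss(gan_type='vanilla', real_label_val=1.0, fake_label_val=0.0, loss_weight=0.1)

real_pred = torch.randn(4, 1)  # stand-in for discriminator(real_img)
fake_pred = torch.randn(4, 1)  # stand-in for discriminator(fake_img)

# generator step: try to make fakes look real; loss_weight is applied
l_g_gan = cri_gan(fake_pred, target_is_real=True, is_disc=False)

# discriminator steps: loss_weight is forced to 1.0 when is_disc=True
l_d_real = cri_gan(real_pred, target_is_real=True, is_disc=True)
l_d_fake = cri_gan(fake_pred.detach(), target_is_real=False, is_disc=True)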
BasicSR/basicsr/losses/loss_util.py  0 → 100644

import functools
import torch
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute the mean over the weight region
    elif reduction == 'mean':
        if weight.size(1) > 1:
            weight = weight.sum()
        else:
            weight = weight.sum() * loss.size(1)
        loss = loss.sum() / weight

    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss

    return wrapper


def get_local_weights(residual, ksize):
    """Get local weights for generating the artifact map of LDL.

    It is only called by the `get_refined_artifact_map` function.

    Args:
        residual (Tensor): Residual between predicted and ground truth images.
        ksize (int): Size of the local window.

    Returns:
        Tensor: Weight for each pixel to be discriminated as an artifact pixel.
    """
    pad = (ksize - 1) // 2
    residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')

    unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
    pixel_level_weight = torch.var(
        unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)

    return pixel_level_weight


def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
    """Calculate the artifact map of LDL
    (Details or Artifacts: A Locally Discriminative Learning Approach to
    Realistic Image Super-Resolution. In CVPR 2022).

    Args:
        img_gt (Tensor): Ground truth images.
        img_output (Tensor): Output images given by the optimizing model.
        img_ema (Tensor): Output images given by the ema model.
        ksize (int): Size of the local window.

    Returns:
        overall_weight: Weight for each pixel to be discriminated as an
            artifact pixel (calculated based on both local and global
            observations).
    """
    residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
    residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)

    patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
    pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
    overall_weight = patch_level_weight * pixel_level_weight

    overall_weight[residual_sr < residual_ema] = 0

    return overall_weight
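
A hedged shape-level sketch for `get_refined_artifact_map`: the three inputs are random stand-ins, whereas in LDL training they would be the ground truth, the current model output, and the EMA model output.

import torch

img_gt = torch.rand(1, 3, 64, 64)
img_output = torch.rand(1, 3, 64, 64)  # stand-in for the optimizing model's output
img_ema = torch.rand(1, 3, 64, 64)     # stand-in for the EMA model's output

overall_weight = get_refined_artifact_map(img_gt, img_output, img_ema, ksize=7)
print(overall_weight.shape)  # torch.Size([1, 1, 64, 64]): one artifact weight per pixel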
BasicSR/basicsr/metrics/README.md  0 → 100644

# Metrics

[English](README.md) **|** [简体中文](README_CN.md)

- [Conventions](#conventions)
- [PSNR and SSIM](#psnr-and-ssim)

## Conventions

Because different input types lead to different results, we adopt the following conventions for inputs:

- Numpy type (usually the output of cv2)
  - UINT8: BGR, [0, 255], (h, w, c)
  - float: BGR, [0, 1], (h, w, c). Usually an intermediate result
- Tensor type
  - float: RGB, [0, 1], (n, c, h, w)

Other conventions:

- Functions ending with `_pt` return PyTorch results
- The PyTorch version supports batch computation
- Color conversion is done in float32; metric computation is done in float64

## PSNR and SSIM

PSNR and SSIM results generally trend together: a higher PSNR usually comes with a higher SSIM.

In terms of implementation, the various PSNR implementations agree closely with each other. SSIM has many different implementations; ours stays consistent with the original MATLAB version (see the [evaluation code](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378) of the [NTIRE17 challenge](https://competitions.codalab.org/competitions/16306#participate)).

The comparison of different implementations is listed below.
Summary: the PyTorch implementation matches the MATLAB implementation almost exactly, with slight differences when running on GPU.

- PSNR comparison

| Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
| :--- | :---: | :---: | :---: | :---: | :---: |
| baboon | RGB | 20.419710 | 20.419710 | 20.419710 | 20.419710 |
| baboon | Y | - | 22.441898 | 22.441899 | 22.444916 |
| comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
| comic | Y | - | 21.720398 | 21.720398 | 21.721663 |

- SSIM comparison

| Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
| :--- | :---: | :---: | :---: | :---: | :---: |
| baboon | RGB | 0.391853 | 0.391853 | 0.391853 | 0.391853 |
| baboon | Y | - | 0.453097 | 0.453097 | 0.453171 |
| comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738 |
| comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
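
For reference, PSNR follows directly from the MSE for images in the [0, 255] range. Below is a minimal NumPy sketch consistent with the conventions above (metric computation in float64); the repository's full `calculate_psnr` additionally handles border cropping, input orders and Y-channel conversion.

import numpy as np

def psnr_sketch(img1, img2):
    # img1/img2: uint8 BGR images, shape (h, w, c), range [0, 255]
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64))**2)
    if mse == 0:
        return float('inf')
    return 10. * np.log10(255.**2 / mse)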
BasicSR/basicsr/metrics/README_CN.md  0 → 100644

# Metrics

[English](README.md) **|** [简体中文](README_CN.md)

- [约定](#约定)
- [PSNR 和 SSIM](#psnr-和-ssim)

## 约定

因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定:

- Numpy 类型 (一般是 cv2 的结果)
  - UINT8: BGR, [0, 255], (h, w, c)
  - float: BGR, [0, 1], (h, w, c). 一般作为中间结果
- Tensor 类型
  - float: RGB, [0, 1], (n, c, h, w)

其他约定:

- 以 `_pt` 结尾的是 PyTorch 结果
- PyTorch version 支持 batch 计算
- 颜色转换在 float32 上做;metric 计算在 float64 上做

## PSNR 和 SSIM

PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。
在实现上,PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持一致 (参考 [NTIRE17 比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation 代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378))。

下面列了各个实现的结果比对。
总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异。

- PSNR 比对

| Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
| :--- | :---: | :---: | :---: | :---: | :---: |
| baboon | RGB | 20.419710 | 20.419710 | 20.419710 | 20.419710 |
| baboon | Y | - | 22.441898 | 22.441899 | 22.444916 |
| comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
| comic | Y | - | 21.720398 | 21.720398 | 21.721663 |

- SSIM 比对

| Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
| :--- | :---: | :---: | :---: | :---: | :---: |
| baboon | RGB | 0.391853 | 0.391853 | 0.391853 | 0.391853 |
| baboon | Y | - | 0.453097 | 0.453097 | 0.453171 |
| comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738 |
| comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
BasicSR/basicsr/metrics/__init__.py  0 → 100644

from copy import deepcopy

from basicsr.utils.registry import METRIC_REGISTRY
from .niqe import calculate_niqe
from .psnr_ssim import calculate_psnr, calculate_ssim

__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']


def calculate_metric(data, opt):
    """Calculate metric from data and options.

    Args:
        data (dict): Data inputs for the metric function, passed as kwargs.
        opt (dict): Configuration. It must contain:
            type (str): Metric type.
    """
    opt = deepcopy(opt)
    metric_type = opt.pop('type')
    metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
    return metric
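
Illustrative call: `calculate_metric` merges the data dict and the remaining option keys into the metric function's kwargs. The inputs here are hypothetical uint8 BGR ndarrays, and `crop_border=4` is an arbitrary example value.

data = {'img': sr_img, 'img2': gt_img}  # keys must match the metric's signature
opt = {'type': 'calculate_psnr', 'crop_border': 4, 'test_y_channel': False}
psnr = calculate_metric(data, opt)
# expands to calculate_psnr(img=sr_img, img2=gt_img, crop_border=4, test_y_channel=False)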
BasicSR/basicsr/metrics/fid.py  0 → 100644

import numpy as np
import torch
import torch.nn as nn
from scipy import linalg
from tqdm import tqdm

from basicsr.archs.inception import InceptionV3


def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False):
    # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
    # does resize the input.
    inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input)
    inception = nn.DataParallel(inception).eval().to(device)
    return inception


@torch.no_grad()
def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'):
    """Extract inception features.

    Args:
        data_generator (generator): A data generator.
        inception (nn.Module): Inception model.
        len_generator (int): Length of the data_generator to show the
            progressbar. Default: None.
        device (str): Device. Default: cuda.

    Returns:
        Tensor: Extracted features.
    """
    if len_generator is not None:
        pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
    else:
        pbar = None
    features = []

    for data in data_generator:
        if pbar:
            pbar.update(1)
        data = data.to(device)
        feature = inception(data)[0].view(data.shape[0], -1)
        features.append(feature.to('cpu'))
    if pbar:
        pbar.close()
    features = torch.cat(features, 0)
    return features


def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians
    X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is:
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Args:
        mu1 (np.array): The sample mean over activations.
        sigma1 (np.array): The covariance matrix over activations for generated samples.
        mu2 (np.array): The sample mean over activations, precalculated on a representative data set.
        sigma2 (np.array): The covariance matrix over activations, precalculated on a representative data set.

    Returns:
        float: The Frechet Distance.
    """
    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Two covariances have different dimensions'

    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)

    # Product might be almost singular
    if not np.isfinite(cov_sqrt).all():
        print(f'Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates')
        offset = np.eye(sigma1.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error might give a slight imaginary component
    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')
        cov_sqrt = cov_sqrt.real

    mean_diff = mu1 - mu2
    mean_norm = mean_diff @ mean_diff

    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
    fid = mean_norm + trace

    return fid
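
A sketch of the typical end-to-end flow: extract Inception features for two image sets, fit Gaussian moments with NumPy, then call `calculate_fid`. The two data loaders are assumed generators yielding (n, 3, h, w) float tensors; everything else uses only this file's functions.

import numpy as np

inception = load_patched_inception_v3(device='cuda')
feat_fake = extract_inception_features(fake_loader, inception).numpy()  # fake_loader: assumed generator
feat_real = extract_inception_features(real_loader, inception).numpy()  # real_loader: assumed generator

mu_fake, sigma_fake = np.mean(feat_fake, axis=0), np.cov(feat_fake, rowvar=False)
mu_real, sigma_real = np.mean(feat_real, axis=0), np.cov(feat_real, rowvar=False)
fid = calculate_fid(mu_fake, sigma_fake, mu_real, sigma_real)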
BasicSR/basicsr/metrics/metric_util.py  0 → 100644

import numpy as np

from basicsr.utils import bgr2ycbcr


def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: Reordered image.
    """
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
    if len(img.shape) == 2:
        img = img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img


def to_y_channel(img):
    """Change to the Y channel of YCbCr.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without rounding.
    """
    img = img.astype(np.float32) / 255.
    if img.ndim == 3 and img.shape[2] == 3:
        img = bgr2ycbcr(img, y_only=True)
        img = img[..., None]
    return img * 255.
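
A quick shape-level sketch of the two helpers above, using random data with a CHW input order:

import numpy as np

img_chw = np.random.randint(0, 256, (3, 32, 32)).astype(np.float32)
img_hwc = reorder_image(img_chw, input_order='CHW')
print(img_hwc.shape)  # (32, 32, 3)
img_y = to_y_channel(img_hwc)  # BGR -> Y of YCbCr, float, range [0, 255]
print(img_y.shape)  # (32, 32, 1)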
BasicSR/basicsr/metrics/niqe.py  0 → 100644

import cv2
import math
import numpy as np
import os
from scipy.ndimage import convolve
from scipy.special import gamma

from basicsr.metrics.metric_util import reorder_image, to_y_channel
from basicsr.utils.matlab_functions import imresize
from basicsr.utils.registry import METRIC_REGISTRY


def estimate_aggd_param(block):
    """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.

    Args:
        block (ndarray): 2D Image block.

    Returns:
        tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
            distribution (estimating the parameters in Equation 7 in the paper).
    """
    block = block.flatten()
    gam = np.arange(0.2, 10.001, 0.001)  # len = 9801
    gam_reciprocal = np.reciprocal(gam)
    r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))

    left_std = np.sqrt(np.mean(block[block < 0]**2))
    right_std = np.sqrt(np.mean(block[block > 0]**2))
    gammahat = left_std / right_std
    rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
    rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
    array_position = np.argmin((r_gam - rhatnorm)**2)

    alpha = gam[array_position]
    beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
    beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
    return (alpha, beta_l, beta_r)


def compute_feature(block):
    """Compute features.

    Args:
        block (ndarray): 2D Image block.

    Returns:
        list: Features with length of 18.
    """
    feat = []
    alpha, beta_l, beta_r = estimate_aggd_param(block)
    feat.extend([alpha, (beta_l + beta_r) / 2])

    # distortions disturb the fairly regular structure of natural images.
    # This deviation can be captured by analyzing the sample distribution of
    # the products of pairs of adjacent coefficients computed along
    # horizontal, vertical and diagonal orientations.
    shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
    for i in range(len(shifts)):
        shifted_block = np.roll(block, shifts[i], axis=(0, 1))
        alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
        # Eq. 8
        mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
        feat.extend([alpha, mean, beta_l, beta_r])
    return feat


def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    ``Paper: Making a "Completely Blind" Image Quality Analyzer``

    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip

    Note that we do not include block overlap height and width, since they are
    always 0 in the official implementation.

    For good performance, the official implementation advises dividing the
    distorted image into patches of the same size as those used for the
    construction of the multivariate Gaussian model.

    Args:
        img (ndarray): Input image whose quality needs to be computed. The
            image must be a gray or Y (of YCbCr) image with shape (h, w).
            Range [0, 255] with float type.
        mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
            model calculated on the pristine dataset.
        cov_pris_param (ndarray): Covariance of a pre-defined multivariate
            Gaussian model calculated on the pristine dataset.
        gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the image.
        block_size_h (int): Height of the blocks into which the image is divided.
            Default: 96 (the official recommended value).
        block_size_w (int): Width of the blocks into which the image is divided.
            Default: 96 (the official recommended value).
    """
    assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
    # crop image
    h, w = img.shape
    num_block_h = math.floor(h / block_size_h)
    num_block_w = math.floor(w / block_size_w)
    img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]

    distparam = []  # dist param is actually the multiscale features
    for scale in (1, 2):  # perform on two scales (1, 2)
        mu = convolve(img, gaussian_window, mode='nearest')
        sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
        # normalize, as in Eq. 1 in the paper
        img_normalized = (img - mu) / (sigma + 1)

        feat = []
        for idx_w in range(num_block_w):
            for idx_h in range(num_block_h):
                # process each block
                block = img_normalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
                                       idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
                feat.append(compute_feature(block))

        distparam.append(np.array(feat))

        if scale == 1:
            img = imresize(img / 255., scale=0.5, antialiasing=True)
            img = img * 255.

    distparam = np.concatenate(distparam, axis=1)

    # fit a MVG (multivariate Gaussian) model to distorted patch features
    mu_distparam = np.nanmean(distparam, axis=0)
    # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
    distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
    cov_distparam = np.cov(distparam_no_nan, rowvar=False)

    # compute niqe quality, Eq. 10 in the paper
    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
    quality = np.matmul(
        np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))

    quality = np.sqrt(quality)
    quality = float(np.squeeze(quality))
    return quality


@METRIC_REGISTRY.register()
def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    ``Paper: Making a "Completely Blind" Image Quality Analyzer``

    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip

    > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
    > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)

    We use the official params estimated from the pristine dataset.
    We use the recommended block size (96, 96) without overlaps.

    Args:
        img (ndarray): Input image whose quality needs to be computed.
            The input image must be in range [0, 255] with float/int type.
            The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
            If the input order is 'HWC' or 'CHW', it will be converted to gray
            or Y (of YCbCr) image according to the ``convert_to`` argument.
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the metric calculation.
        input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
            Default: 'HWC'.
        convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
            Default: 'y'.

    Returns:
        float: NIQE result.
    """
    ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
    # we use the official params estimated from the pristine dataset.
    niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz'))
    mu_pris_param = niqe_pris_params['mu_pris_param']
    cov_pris_param = niqe_pris_params['cov_pris_param']
    gaussian_window = niqe_pris_params['gaussian_window']

    img = img.astype(np.float32)
    if input_order != 'HW':
        img = reorder_image(img, input_order=input_order)
        if convert_to == 'y':
            img = to_y_channel(img)
        elif convert_to == 'gray':
            img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
        img = np.squeeze(img)

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border]

    # round is necessary for being consistent with MATLAB's result
    img = img.round()

    niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)

    return niqe_result
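
A usage sketch matching the reference values quoted in the docstring; `cv2.imread` returns a BGR uint8 array, which fits the expected input.

import cv2

img = cv2.imread('tests/data/baboon.png')  # test image referenced in the docstring
niqe_score = calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y')
print(niqe_score)  # ≈ 5.7296 per the reference values above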
BasicSR/basicsr/metrics/niqe_pris_params.npz  0 → 100644  (binary file added)