Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ControlNet_pytorch
Commits
e2696ece
Commit
e2696ece
authored
Nov 22, 2023
by
mashun1
Browse files
controlnet
parents
Pipeline
#643
canceled with stages
Changes
263
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1358 additions
and
0 deletions
+1358
-0
BasicSR/scripts/data_preparation/create_lmdb.py
BasicSR/scripts/data_preparation/create_lmdb.py
+174
-0
BasicSR/scripts/data_preparation/download_datasets.py
BasicSR/scripts/data_preparation/download_datasets.py
+70
-0
BasicSR/scripts/data_preparation/extract_images_from_tfrecords.py
...scripts/data_preparation/extract_images_from_tfrecords.py
+199
-0
BasicSR/scripts/data_preparation/extract_subimages.py
BasicSR/scripts/data_preparation/extract_subimages.py
+156
-0
BasicSR/scripts/data_preparation/generate_meta_info.py
BasicSR/scripts/data_preparation/generate_meta_info.py
+34
-0
BasicSR/scripts/data_preparation/prepare_hifacegan_dataset.py
...cSR/scripts/data_preparation/prepare_hifacegan_dataset.py
+113
-0
BasicSR/scripts/data_preparation/regroup_reds_dataset.py
BasicSR/scripts/data_preparation/regroup_reds_dataset.py
+34
-0
BasicSR/scripts/dist_test.sh
BasicSR/scripts/dist_test.sh
+16
-0
BasicSR/scripts/dist_train.sh
BasicSR/scripts/dist_train.sh
+16
-0
BasicSR/scripts/download_gdrive.py
BasicSR/scripts/download_gdrive.py
+12
-0
BasicSR/scripts/download_pretrained_models.py
BasicSR/scripts/download_pretrained_models.py
+112
-0
BasicSR/scripts/matlab_scripts/back_projection/backprojection.m
...R/scripts/matlab_scripts/back_projection/backprojection.m
+20
-0
BasicSR/scripts/matlab_scripts/back_projection/main_bp.m
BasicSR/scripts/matlab_scripts/back_projection/main_bp.m
+22
-0
BasicSR/scripts/matlab_scripts/back_projection/main_reverse_filter.m
...ipts/matlab_scripts/back_projection/main_reverse_filter.m
+25
-0
BasicSR/scripts/matlab_scripts/generate_LR_Vimeo90K.m
BasicSR/scripts/matlab_scripts/generate_LR_Vimeo90K.m
+49
-0
BasicSR/scripts/matlab_scripts/generate_bicubic_img.m
BasicSR/scripts/matlab_scripts/generate_bicubic_img.m
+87
-0
BasicSR/scripts/metrics/calculate_fid_folder.py
BasicSR/scripts/metrics/calculate_fid_folder.py
+74
-0
BasicSR/scripts/metrics/calculate_fid_stats_from_datasets.py
BasicSR/scripts/metrics/calculate_fid_stats_from_datasets.py
+61
-0
BasicSR/scripts/metrics/calculate_lpips.py
BasicSR/scripts/metrics/calculate_lpips.py
+50
-0
BasicSR/scripts/metrics/calculate_niqe.py
BasicSR/scripts/metrics/calculate_niqe.py
+34
-0
No files found.
Too many changes to show.
To preserve performance only
263 of 263+
files are displayed.
Plain diff
Email patch
BasicSR/scripts/data_preparation/create_lmdb.py
0 → 100644
View file @
e2696ece
import
argparse
from
os
import
path
as
osp
from
basicsr.utils
import
scandir
from
basicsr.utils.lmdb_util
import
make_lmdb_from_imgs
def create_lmdb_for_div2k():
    """Create lmdb files for the DIV2K dataset.

    Usage:
        Before running this script, please run `extract_subimages.py`.
        Typically, there are four folders to be processed for DIV2K:

            * DIV2K_train_HR_sub
            * DIV2K_train_LR_bicubic/X2_sub
            * DIV2K_train_LR_bicubic/X3_sub
            * DIV2K_train_LR_bicubic/X4_sub

        Remember to modify the paths below according to your settings.
    """
    # (source folder, target lmdb) pairs: HR first, then LRx2/x3/x4.
    folder_lmdb_pairs = [
        ('datasets/DIV2K/DIV2K_train_HR_sub',
         'datasets/DIV2K/DIV2K_train_HR_sub.lmdb'),
        ('datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub',
         'datasets/DIV2K/DIV2K_train_LR_bicubic_X2_sub.lmdb'),
        ('datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub',
         'datasets/DIV2K/DIV2K_train_LR_bicubic_X3_sub.lmdb'),
        ('datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub',
         'datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb'),
    ]
    for folder_path, lmdb_path in folder_lmdb_pairs:
        img_path_list, keys = prepare_keys_div2k(folder_path)
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
def prepare_keys_div2k(folder_path):
    """Collect image paths and lmdb keys for the DIV2K dataset.

    Args:
        folder_path (str): Folder path.

    Returns:
        list[str]: Image path list.
        list[str]: Key list (file names with the '.png' suffix stripped).
    """
    print('Reading image path list ...')
    # Non-recursive scan: DIV2K sub-images live in a flat folder.
    img_path_list = sorted(scandir(folder_path, suffix='png', recursive=False))
    keys = [name.split('.png')[0] for name in img_path_list]
    return img_path_list, keys
def create_lmdb_for_reds():
    """Create lmdb files for the REDS dataset.

    Usage:
        Before running this script, please run :file:`merge_reds_train_val.py`.
        We take two folders for example:

            * train_sharp
            * train_sharp_bicubic

        Remember to modify the paths below according to your settings.
    """
    # (source folder, target lmdb) pairs; both use multiprocessing reads
    # because REDS folders hold many large frames.
    folder_lmdb_pairs = [
        ('datasets/REDS/train_sharp',
         'datasets/REDS/train_sharp_with_val.lmdb'),
        ('datasets/REDS/train_sharp_bicubic',
         'datasets/REDS/train_sharp_bicubic_with_val.lmdb'),
    ]
    for folder_path, lmdb_path in folder_lmdb_pairs:
        img_path_list, keys = prepare_keys_reds(folder_path)
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
def prepare_keys_reds(folder_path):
    """Collect image paths and lmdb keys for the REDS dataset.

    Args:
        folder_path (str): Folder path.

    Returns:
        list[str]: Image path list.
        list[str]: Key list, e.g. '000/00000000'.
    """
    print('Reading image path list ...')
    # Recursive scan: REDS frames are grouped into per-clip sub-folders.
    img_path_list = sorted(scandir(folder_path, suffix='png', recursive=True))
    keys = [path.split('.png')[0] for path in img_path_list]  # example: 000/00000000
    return img_path_list, keys
def create_lmdb_for_vimeo90k():
    """Create lmdb files for the Vimeo90K dataset.

    Usage:
        Remember to modify the paths below according to your settings.
    """
    train_list = 'datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
    # (source folder, target lmdb, mode); 'gt' keeps only the 4th frame,
    # 'lq' keeps all 7 frames of every septuplet.
    jobs = [
        ('datasets/vimeo90k/vimeo_septuplet/sequences',
         'datasets/vimeo90k/vimeo90k_train_GT_only4th.lmdb', 'gt'),
        ('datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences',
         'datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb', 'lq'),
    ]
    for folder_path, lmdb_path, mode in jobs:
        img_path_list, keys = prepare_keys_vimeo90k(folder_path, train_list, mode)
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
def prepare_keys_vimeo90k(folder_path, train_list_path, mode):
    """Collect image paths and lmdb keys for the Vimeo90K dataset.

    Args:
        folder_path (str): Folder path (unused for path construction beyond
            documentation; paths are relative to it).
        train_list_path (str): Path to the official train list, one
            'folder/sub_folder' clip per line.
        mode (str): One of 'gt' or 'lq'. In 'gt' mode only the 4th frame of
            each septuplet is kept.

    Returns:
        list[str]: Image path list.
        list[str]: Key list.
    """
    print('Reading image path list ...')
    with open(train_list_path, 'r') as fin:
        clips = [line.strip() for line in fin]

    img_path_list = []
    keys = []
    for clip in clips:
        folder, sub_folder = clip.split('/')
        # Each clip is a septuplet: frames im1.png .. im7.png.
        for j in range(7):
            img_path_list.append(osp.join(folder, sub_folder, f'im{j + 1}.png'))
            keys.append(f'{folder}/{sub_folder}/im{j + 1}')

    if mode == 'gt':
        print('Only keep the 4th frame for the gt mode.')
        img_path_list = [path for path in img_path_list if path.endswith('im4.png')]
        keys = [key for key in keys if key.endswith('/im4')]
    return img_path_list, keys
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset',
        type=str,
        help=("Options: 'DIV2K', 'REDS', 'Vimeo90K' You may need to modify the corresponding configurations in codes."))
    args = parser.parse_args()
    dataset = args.dataset.lower()

    # Dispatch table keyed by the lower-cased dataset name.
    handlers = {
        'div2k': create_lmdb_for_div2k,
        'reds': create_lmdb_for_reds,
        'vimeo90k': create_lmdb_for_vimeo90k,
    }
    handler = handlers.get(dataset)
    if handler is None:
        raise ValueError('Wrong dataset.')
    handler()
BasicSR/scripts/data_preparation/download_datasets.py
0 → 100644
View file @
e2696ece
import
argparse
import
glob
import
os
from
os
import
path
as
osp
from
basicsr.utils.download_util
import
download_file_from_google_drive
def download_dataset(dataset, file_ids):
    """Download a dataset archive from Google Drive and unpack it.

    Args:
        dataset (str): Dataset name (used only for the caller's bookkeeping).
        file_ids (dict[str, str]): Mapping of file name -> Google Drive id.

    Side effects:
        Creates './datasets/', downloads each file there (asking for
        interactive confirmation before overwriting), extracts every
        '.zip' archive, and flattens a single same-named sub-folder
        produced by the archive.
    """
    save_path_root = './datasets/'
    os.makedirs(save_path_root, exist_ok=True)

    for file_name, file_id in file_ids.items():
        save_path = osp.abspath(osp.join(save_path_root, file_name))
        if osp.exists(save_path):
            # Interactive prompt: only 'y'/'n' (any case) are accepted.
            user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
            if user_response.lower() == 'y':
                print(f'Covering {file_name} to {save_path}')
                download_file_from_google_drive(file_id, save_path)
            elif user_response.lower() == 'n':
                print(f'Skipping {file_name}')
            else:
                raise ValueError('Wrong input. Only accepts Y/N.')
        else:
            print(f'Downloading {file_name} to {save_path}')
            download_file_from_google_drive(file_id, save_path)

        # unzip
        if save_path.endswith('.zip'):
            # Extract 'Foo.zip' into a sibling folder 'Foo'.
            extracted_path = save_path.replace('.zip', '')
            print(f'Extract {save_path} to {extracted_path}')
            import zipfile
            with zipfile.ZipFile(save_path, 'r') as zip_ref:
                zip_ref.extractall(extracted_path)

            # If the archive itself contains a top-level 'Foo/' folder,
            # move its contents up one level and remove the duplicate dir.
            file_name = file_name.replace('.zip', '')
            subfolder = osp.join(extracted_path, file_name)
            if osp.isdir(subfolder):
                print(f'Move {subfolder} to {extracted_path}')
                import shutil
                for path in glob.glob(osp.join(subfolder, '*')):
                    shutil.move(path, extracted_path)
                shutil.rmtree(subfolder)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'dataset',
        type=str,
        help=("Options: 'Set5', 'Set14'. "
              "Set to 'all' if you want to download all the dataset."))
    args = parser.parse_args()

    # Mapping: dataset name -> {archive file name: Google Drive file id}.
    file_ids = {
        'Set5': {
            'Set5.zip':  # file name
            '1RtyIeUFTyW8u7oa4z7a0lSzT3T1FwZE9',  # file id
        },
        'Set14': {
            'Set14.zip':
            '1vsw07sV8wGrRQ8UARe2fO5jjgy9QJy_E',
        }
    }

    # 'all' downloads every known dataset; any other value must be a key of
    # `file_ids`, otherwise a KeyError is raised.
    if args.dataset == 'all':
        for dataset in file_ids.keys():
            download_dataset(dataset, file_ids[dataset])
    else:
        download_dataset(args.dataset, file_ids[args.dataset])
BasicSR/scripts/data_preparation/extract_images_from_tfrecords.py
0 → 100644
View file @
e2696ece
import
argparse
import
cv2
import
glob
import
numpy
as
np
import
os
from
basicsr.utils.lmdb_util
import
LmdbMaker
def convert_celeba_tfrecords(tf_file, log_resolution, save_root, save_type='img', compress_level=1):
    """Convert CelebA tfrecords to images or lmdb files.

    Args:
        tf_file (str): Input tfrecords file in glob pattern.
            Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords'  # noqa:E501
        log_resolution (int): Log scale of resolution.
        save_root (str): Path root to save.
        save_type (str): Save type. Options: img | lmdb. Default: img.
        compress_level (int): Compress level when encoding images. Default: 1.

    Raises:
        ValueError: If ``save_type`` is neither 'img' nor 'lmdb'.
    """
    # The split ('validation' vs 'train') is encoded in the tfrecords path.
    if 'validation' in tf_file:
        phase = 'validation'
    else:
        phase = 'train'
    if save_type == 'lmdb':
        save_path = os.path.join(save_root, f'celeba_{2**log_resolution}_{phase}.lmdb')
        lmdb_maker = LmdbMaker(save_path)
    elif save_type == 'img':
        save_path = os.path.join(save_root, f'celeba_{2**log_resolution}_{phase}')
    else:
        raise ValueError('Wrong save type.')

    os.makedirs(save_path, exist_ok=True)

    idx = 0
    for record in sorted(glob.glob(tf_file)):
        print('Processing record: ', record)
        # NOTE(review): `tf` is imported inside the __main__ guard of this
        # script (TF 1.x API), so this function only works when the script is
        # run directly — confirm before reusing it as a library function.
        record_iterator = tf.python_io.tf_record_iterator(record)
        for string_record in record_iterator:
            example = tf.train.Example()
            example.ParseFromString(string_record)

            # Other fields available in the records (unused here):
            # label = example.features.feature['label'].int64_list.value[0]
            # attr = example.features.feature['attr'].int64_list.value
            # male = attr[20]
            # young = attr[39]

            shape = example.features.feature['shape'].int64_list.value
            h, w, c = shape
            img_str = example.features.feature['data'].bytes_list.value[0]
            # Fix: np.frombuffer replaces the long-deprecated np.fromstring
            # (binary mode removed in recent NumPy). frombuffer returns a
            # read-only view, but the fancy-index channel swap below copies
            # the data, so nothing ever writes into the view.
            img = np.frombuffer(img_str, dtype=np.uint8).reshape((h, w, c))
            img = img[:, :, [2, 1, 0]]  # RGB -> BGR, the order cv2 expects

            if save_type == 'img':
                cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img)
            elif save_type == 'lmdb':
                _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
                key = f'{idx:08d}/r{log_resolution:02d}'
                lmdb_maker.put(img_byte, key, (h, w, c))
            idx += 1
            print(idx)
    if save_type == 'lmdb':
        lmdb_maker.close()
def convert_ffhq_tfrecords(tf_file, log_resolution, save_root, save_type='img', compress_level=1):
    """Convert FFHQ tfrecords to images or lmdb files.

    Args:
        tf_file (str): Input tfrecords file.
        log_resolution (int): Log scale of resolution.
        save_root (str): Path root to save.
        save_type (str): Save type. Options: img | lmdb. Default: img.
        compress_level (int): Compress level when encoding images. Default: 1.

    Raises:
        ValueError: If ``save_type`` is neither 'img' nor 'lmdb'.
    """
    if save_type == 'lmdb':
        save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}.lmdb')
        lmdb_maker = LmdbMaker(save_path)
    elif save_type == 'img':
        save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}')
    else:
        raise ValueError('Wrong save type.')

    os.makedirs(save_path, exist_ok=True)

    idx = 0
    for record in sorted(glob.glob(tf_file)):
        print('Processing record: ', record)
        # NOTE(review): `tf` is imported inside the __main__ guard of this
        # script (TF 1.x API) — this function only works when run as a script.
        record_iterator = tf.python_io.tf_record_iterator(record)
        for string_record in record_iterator:
            example = tf.train.Example()
            example.ParseFromString(string_record)

            shape = example.features.feature['shape'].int64_list.value
            # FFHQ records store CHW (unlike CelebA's HWC).
            c, h, w = shape
            img_str = example.features.feature['data'].bytes_list.value[0]
            # Fix: np.frombuffer replaces the long-deprecated np.fromstring
            # (binary mode removed in recent NumPy). The read-only view is
            # never mutated: transpose returns a view and the fancy-index
            # channel swap below copies the data.
            img = np.frombuffer(img_str, dtype=np.uint8).reshape((c, h, w))
            img = img.transpose(1, 2, 0)  # CHW -> HWC
            img = img[:, :, [2, 1, 0]]  # RGB -> BGR, the order cv2 expects

            if save_type == 'img':
                cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img)
            elif save_type == 'lmdb':
                _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
                key = f'{idx:08d}/r{log_resolution:02d}'
                lmdb_maker.put(img_byte, key, (h, w, c))
            idx += 1
            print(idx)
    if save_type == 'lmdb':
        lmdb_maker.close()
def make_ffhq_lmdb_from_imgs(folder_path, log_resolution, save_root, save_type='lmdb', compress_level=1):
    """Make an FFHQ lmdb from a folder of images.

    Args:
        folder_path (str): Folder path containing the source images.
        log_resolution (int): Log scale of resolution (used only in names/keys).
        save_root (str): Path root to save.
        save_type (str): Save type. Only 'lmdb' is supported.
        compress_level (int): Compress level when encoding images. Default: 1.

    Raises:
        ValueError: If ``save_type`` is not 'lmdb'.
    """
    if save_type != 'lmdb':
        raise ValueError('Wrong save type.')
    save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}_crop1.2.lmdb')
    lmdb_maker = LmdbMaker(save_path)
    os.makedirs(save_path, exist_ok=True)

    img_list = sorted(glob.glob(os.path.join(folder_path, '*')))
    for idx, img_path in enumerate(img_list):
        print(f'Processing {idx}: ', img_path)
        img = cv2.imread(img_path)
        h, w, c = img.shape
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        key = f'{idx:08d}/r{log_resolution:02d}'
        lmdb_maker.put(img_byte, key, (h, w, c))
    lmdb_maker.close()
if __name__ == '__main__':
    """Read tfrecords w/o define a graph.

    We have tested it on TensorFlow 1.15

    References: http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='ffhq', help="Dataset name. Options: 'ffhq' | 'celeba'. Default: 'ffhq'.")
    parser.add_argument(
        '--tf_file',
        type=str,
        default='datasets/ffhq/ffhq-r10.tfrecords',
        help=('Input tfrecords file. For celeba, it should be glob pattern. '
              'Put quotes around the wildcard argument to prevent the shell '
              'from expanding it.'
              "Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords'"  # noqa:E501
              ))
    parser.add_argument('--log_resolution', type=int, default=10, help='Log scale of resolution.')
    parser.add_argument('--save_root', type=str, default='datasets/ffhq/', help='Save root path.')
    parser.add_argument('--save_type', type=str, default='img', help="Save type. Options: 'img' | 'lmdb'. Default: 'img'.")
    parser.add_argument('--compress_level', type=int, default=1, help='Compress level when encoding images. Default: 1.')
    args = parser.parse_args()

    # TensorFlow is imported lazily (module level name `tf` used by the
    # convert_* functions above) so the rest of the repo does not depend on it.
    try:
        import tensorflow as tf
    except Exception:
        raise ImportError('You need to install tensorflow to read tfrecords.')

    # Any dataset value other than 'ffhq' falls through to the CelebA path.
    if args.dataset == 'ffhq':
        convert_ffhq_tfrecords(
            args.tf_file, args.log_resolution, args.save_root, save_type=args.save_type, compress_level=args.compress_level)
    else:
        convert_celeba_tfrecords(
            args.tf_file, args.log_resolution, args.save_root, save_type=args.save_type, compress_level=args.compress_level)
BasicSR/scripts/data_preparation/extract_subimages.py
0 → 100644
View file @
e2696ece
import
cv2
import
numpy
as
np
import
os
import
sys
from
multiprocessing
import
Pool
from
os
import
path
as
osp
from
tqdm
import
tqdm
from
basicsr.utils
import
scandir
def main():
    """A multi-thread tool to crop large images to sub-images for faster IO.

    It is used for the DIV2K dataset. For each of the four DIV2K folders
    (HR, LRx2, LRx3, LRx4) the script crops overlapping patches:

        opt keys consumed by extract_subimages/worker:
            n_thread (int): Process-pool size.
            compression_level (int): cv2.IMWRITE_PNG_COMPRESSION, 0-9.
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder.
            crop_size (int): Crop size.
            step (int): Step for the overlapped sliding window.
            thresh_size (int): Patches smaller than this are dropped.

    After processing, every sub-folder should contain the same number of
    sub-images. Remember to modify the configurations below to your settings.
    """
    opt = {'n_thread': 20, 'compression_level': 3, 'thresh_size': 0}

    # (input folder, save folder, crop_size, step) per scale; the crop size
    # shrinks with the downsampling factor so patches stay aligned.
    configs = [
        ('datasets/DIV2K/DIV2K_train_HR',
         'datasets/DIV2K/DIV2K_train_HR_sub', 480, 240),
        ('datasets/DIV2K/DIV2K_train_LR_bicubic/X2',
         'datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub', 240, 120),
        ('datasets/DIV2K/DIV2K_train_LR_bicubic/X3',
         'datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub', 160, 80),
        ('datasets/DIV2K/DIV2K_train_LR_bicubic/X4',
         'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub', 120, 60),
    ]
    for input_folder, save_folder, crop_size, step in configs:
        opt['input_folder'] = input_folder
        opt['save_folder'] = save_folder
        opt['crop_size'] = crop_size
        opt['step'] = step
        extract_subimages(opt)
def extract_subimages(opt):
    """Crop images to subimages using a process pool.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder.
            n_thread (int): Process-pool size.
            (plus the keys consumed by `worker`).

    Exits the whole program if the save folder already exists, to avoid
    silently mixing old and new crops.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print(f'mkdir {save_folder} ...')
    else:
        print(f'Folder {save_folder} already exists. Exit.')
        sys.exit(1)

    img_list = list(scandir(input_folder, full_path=True))

    pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
    pool = Pool(opt['n_thread'])
    for path in img_list:
        # The callback runs in the parent process once each image finishes,
        # advancing the progress bar by one.
        pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()
    pbar.close()
    print('All processes done.')
def worker(path, opt):
    """Worker for each process: crop one image into overlapping patches.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            crop_size (int): Crop size.
            step (int): Step for the overlapped sliding window.
            thresh_size (int): If the leftover border exceeds this, an extra
                patch flush with the image edge is added.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        process_info (str): Process information displayed in progress bar.
    """
    crop_size = opt['crop_size']
    step = opt['step']
    thresh_size = opt['thresh_size']
    img_name, extension = osp.splitext(osp.basename(path))

    # remove the x2, x3, x4 and x8 in the filename for DIV2K
    img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    h, w = img.shape[0:2]
    # Top-left corners of the sliding window along each axis; if the last
    # window leaves more than thresh_size pixels uncovered, add one final
    # window flush with the image edge.
    # NOTE(review): if the image is smaller than crop_size, h_space/w_space
    # are empty and the [-1] index below raises IndexError — confirm inputs
    # are always at least crop_size on each side.
    h_space = np.arange(0, h - crop_size + 1, step)
    if h - (h_space[-1] + crop_size) > thresh_size:
        h_space = np.append(h_space, h - crop_size)
    w_space = np.arange(0, w - crop_size + 1, step)
    if w - (w_space[-1] + crop_size) > thresh_size:
        w_space = np.append(w_space, w - crop_size)

    index = 0
    for x in h_space:
        for y in w_space:
            index += 1
            cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
            # Contiguous copy so cv2 writes without re-packing the slice.
            cropped_img = np.ascontiguousarray(cropped_img)
            cv2.imwrite(
                osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    process_info = f'Processing {img_name} ...'
    return process_info
if __name__ == '__main__':
    # Script entry point: crop all configured DIV2K folders into sub-images.
    main()
BasicSR/scripts/data_preparation/generate_meta_info.py
0 → 100644
View file @
e2696ece
from
os
import
path
as
osp
from
PIL
import
Image
from
basicsr.utils
import
scandir
def generate_meta_info_div2k():
    """Generate the meta info text file for the DIV2K dataset.

    Writes one line per image: '<name> (<height>,<width>,<channels>)'.
    Only 'RGB' (3-channel) and 'L' (1-channel) images are accepted.
    """
    gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/'
    meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'

    img_list = sorted(scandir(gt_folder))

    # PIL mode -> channel count for the modes we support.
    mode_to_channels = {'RGB': 3, 'L': 1}

    with open(meta_info_txt, 'w') as f:
        for idx, img_path in enumerate(img_list):
            img = Image.open(osp.join(gt_folder, img_path))  # lazy load
            width, height = img.size
            mode = img.mode
            n_channel = mode_to_channels.get(mode)
            if n_channel is None:
                raise ValueError(f'Unsupported mode {mode}.')

            info = f'{img_path} ({height},{width},{n_channel})'
            print(idx + 1, info)
            f.write(f'{info}\n')
if __name__ == '__main__':
    # Script entry point: write the DIV2K sub-image meta info file.
    generate_meta_info_div2k()
BasicSR/scripts/data_preparation/prepare_hifacegan_dataset.py
0 → 100644
View file @
e2696ece
import
cv2
import
os
from
tqdm
import
tqdm
class Mosaic16x:
    """A customized image augmentor for 16-pixel mosaic.

    Replaces each 16x16 tile of the image with the tile's mean value,
    producing a pixelated ("mosaic") effect.
    """

    def augment_image(self, x):
        """Return a uint8 mosaic of ``x`` built from 16x16 tile means."""
        height, width = x.shape[:2]
        work = x.astype('float')  # float working copy avoids uint8 overflow
        n_tile_rows = (height + 15) // 16
        n_tile_cols = (width + 15) // 16
        for row in range(n_tile_rows):
            top, bottom = row * 16, (row + 1) * 16
            for col in range(n_tile_cols):
                left, right = col * 16, (col + 1) * 16
                tile = work[top:bottom, left:right]
                # Per-channel mean over the tile's spatial extent.
                work[top:bottom, left:right] = tile.mean(axis=(0, 1))
        return work.astype('uint8')
class DegradationSimulator:
    """
    Generating training/testing data pairs on the fly.
    The degradation script is aligned with HiFaceGAN paper settings.
    Args:
        opt(str | op): Config for degradation script, with degradation type and parameters
        Custom degradation is possible by passing an inherited class from ia.augmentors
    """

    def __init__(self, ):
        # imgaug is imported lazily so the rest of the repo does not depend
        # on it unless this simulator is actually used.
        import imgaug.augmenters as ia
        # Named degradation pipelines; values are imgaug Augmenters (or the
        # duck-typed Mosaic16x above, which exposes augment_image()).
        self.default_deg_templates = {
            'sr4x':
            ia.Sequential([
                # It's almost like a 4x bicubic downsampling
                ia.Resize((0.25000, 0.25001), cv2.INTER_AREA),
                ia.Resize({'height': 512, 'width': 512}, cv2.INTER_CUBIC),
            ]),
            'sr4x8x':
            ia.Sequential([
                # Random downscale between 8x and 4x, then back to 512x512.
                ia.Resize((0.125, 0.25), cv2.INTER_AREA),
                ia.Resize({'height': 512, 'width': 512}, cv2.INTER_CUBIC),
            ]),
            'denoise':
            ia.OneOf([
                # One randomly chosen noise model per image.
                ia.AdditiveGaussianNoise(scale=(20, 40), per_channel=True),
                ia.AdditiveLaplaceNoise(scale=(20, 40), per_channel=True),
                ia.AdditivePoissonNoise(lam=(15, 30), per_channel=True),
            ]),
            'deblur':
            ia.OneOf([
                ia.MotionBlur(k=(10, 20)),
                ia.GaussianBlur((3.0, 8.0)),
            ]),
            'jpeg':
            ia.JpegCompression(compression=(50, 85)),
            '16x':
            Mosaic16x(),
        }

        # 'face_renov' chains blur + noise + jpeg + downscale in random order.
        rand_deg_list = [
            self.default_deg_templates['deblur'],
            self.default_deg_templates['denoise'],
            self.default_deg_templates['jpeg'],
            self.default_deg_templates['sr4x8x'],
        ]
        self.default_deg_templates['face_renov'] = ia.Sequential(rand_deg_list, random_order=True)

    def create_training_dataset(self, deg, gt_folder, lq_folder=None):
        from imgaug.augmenters.meta import Augmenter  # baseclass
        # NOTE: the string below is not a docstring (it follows the import
        # statement), but it documents the method's intent.
        """
        Create a degradation simulator and apply it to GT images on the fly
        Save the degraded result in the lq_folder (if None, name it as GT_deg)
        """
        if not lq_folder:
            # Derive the LQ folder name from the GT folder and the
            # degradation name (or 'custom' for a user-supplied Augmenter).
            suffix = deg if isinstance(deg, str) else 'custom'
            lq_folder = '_'.join([gt_folder.replace('gt', 'lq'), suffix])
            print(lq_folder)
        os.makedirs(lq_folder, exist_ok=True)

        # Resolve string names against the template table; otherwise the
        # caller must pass an imgaug Augmenter subclass instance.
        if isinstance(deg, str):
            assert deg in self.default_deg_templates, (
                f'Degration type {deg} not recognized: {"|".join(list(self.default_deg_templates.keys()))}')
            deg = self.default_deg_templates[deg]
        else:
            assert isinstance(deg, Augmenter), f'Deg must be either str|Augmenter, got {deg}'

        # Degrade every image in gt_folder and write it under the same name.
        names = os.listdir(gt_folder)
        for name in tqdm(names):
            gt = cv2.imread(os.path.join(gt_folder, name))
            lq = deg.augment_image(gt)
            # pack = np.concatenate([lq, gt], axis=0)
            cv2.imwrite(os.path.join(lq_folder, name), lq)
        print('Dataset prepared.')
if __name__ == '__main__':
    # Apply the default 4x super-resolution degradation to the FFHQ GT set.
    simulator = DegradationSimulator()
    simulator.create_training_dataset('sr4x', 'datasets/FFHQ_512_gt')
BasicSR/scripts/data_preparation/regroup_reds_dataset.py
0 → 100644
View file @
e2696ece
import
glob
import
os
def regroup_reds_dataset(train_path, val_path):
    """Regroup original REDS datasets.

    We merge train and validation data into one folder, and separate the
    validation clips in reds_dataset.py.

    There are 240 training clips (starting from 0 to 239), so we name the
    validation clip index starting from 240 to 269 (total 30 validation
    clips).

    Args:
        train_path (str): Path to the train folder.
        val_path (str): Path to the validation folder.
    """
    import shutil  # local import, matching the style of sibling scripts

    # copy the validation clips into the train folder under shifted indices
    val_folders = glob.glob(os.path.join(val_path, '*'))
    for folder in val_folders:
        # os.path.basename instead of split('/') so the script also works
        # with Windows path separators; clip '000' becomes '240', etc.
        new_folder_idx = int(os.path.basename(folder)) + 240
        # shutil.copytree replaces the original os.system(f'cp -r ...'):
        # portable, no shell spawned, and safe with paths containing spaces
        # or shell metacharacters. Like `cp -r src dst` with a fresh dst,
        # it creates the destination directory and copies the tree into it.
        shutil.copytree(folder, os.path.join(train_path, str(new_folder_idx)))
if __name__ == '__main__':
    # Merge the validation clips into the training folders for both the
    # sharp (GT) and bicubic (LQ) variants of REDS.

    # train_sharp
    train_path = 'datasets/REDS/train_sharp'
    val_path = 'datasets/REDS/val_sharp'
    regroup_reds_dataset(train_path, val_path)

    # train_sharp_bicubic
    train_path = 'datasets/REDS/train_sharp_bicubic/X4'
    val_path = 'datasets/REDS/val_sharp_bicubic/X4'
    regroup_reds_dataset(train_path, val_path)
BasicSR/scripts/dist_test.sh
0 → 100755
View file @
e2696ece
#!/usr/bin/env bash
# Distributed testing launcher for BasicSR.
# Override the master port with: PORT=1234 ./scripts/dist_test.sh ...

GPUS=$1
CONFIG=$2
PORT=${PORT:-4321}

# usage
if [ $# -ne 2 ] ;then
    echo "usage:"
    echo "./scripts/dist_test.sh [number of gpu] [path to option file]"
    exit
fi

# Put the repo root on PYTHONPATH so `basicsr` is importable, then launch
# one test process per GPU via torch.distributed.
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    basicsr/test.py -opt $CONFIG --launcher pytorch
BasicSR/scripts/dist_train.sh
0 → 100755
View file @
e2696ece
#!/usr/bin/env bash
# Distributed training launcher for BasicSR.
# Override the master port with: PORT=1234 ./scripts/dist_train.sh ...
# Arguments after the first two are forwarded verbatim to basicsr/train.py.

GPUS=$1
CONFIG=$2
PORT=${PORT:-4321}

# usage
if [ $# -lt 2 ] ;then
    echo "usage:"
    echo "./scripts/dist_train.sh [number of gpu] [path to option file]"
    exit
fi

# Put the repo root on PYTHONPATH so `basicsr` is importable, then launch
# one training process per GPU via torch.distributed; ${@:3} forwards any
# extra CLI options.
PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    basicsr/train.py -opt $CONFIG --launcher pytorch ${@:3}
BasicSR/scripts/download_gdrive.py
0 → 100644
View file @
e2696ece
import
argparse
from
basicsr.utils.download_util
import
download_file_from_google_drive
if __name__ == '__main__':
    # Small CLI wrapper around download_file_from_google_drive.
    parser = argparse.ArgumentParser()
    parser.add_argument('--id', type=str, help='File id')
    parser.add_argument('--output', type=str, help='Save path')
    args = parser.parse_args()

    # Bug fix: argparse stores '--output' under `args.output`; the original
    # `args.save_path` raised AttributeError on every invocation.
    download_file_from_google_drive(args.id, args.output)
BasicSR/scripts/download_pretrained_models.py
0 → 100644
View file @
e2696ece
import
argparse
import
os
from
os
import
path
as
osp
from
basicsr.utils.download_util
import
download_file_from_google_drive
def download_pretrained_models(method, file_ids):
    """Download pretrained model weights for one method from Google Drive.

    Args:
        method (str): Method name; files are saved under
            './experiments/pretrained_models/<method>'.
        file_ids (dict[str, str]): Mapping of file name -> Google Drive id.

    Side effects:
        Creates the save folder and downloads each file, asking for
        interactive confirmation before overwriting an existing file.
    """
    save_path_root = f'./experiments/pretrained_models/{method}'
    os.makedirs(save_path_root, exist_ok=True)

    for file_name, file_id in file_ids.items():
        save_path = osp.abspath(osp.join(save_path_root, file_name))
        if osp.exists(save_path):
            # Interactive prompt: only 'y'/'n' (any case) are accepted.
            user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
            if user_response.lower() == 'y':
                print(f'Covering {file_name} to {save_path}')
                download_file_from_google_drive(file_id, save_path)
            elif user_response.lower() == 'n':
                print(f'Skipping {file_name}')
            else:
                raise ValueError('Wrong input. Only accepts Y/N.')
        else:
            print(f'Downloading {file_name} to {save_path}')
            download_file_from_google_drive(file_id, save_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'method',
        type=str,
        help=("Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', 'dlib', 'TOF', 'flownet', 'BasicVSR'. "
              "Set to 'all' to download all the models."))
    args = parser.parse_args()

    # Mapping: method name -> {weight file name: Google Drive file id}.
    file_ids = {
        'ESRGAN': {
            'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth':  # file name
            '1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT',  # file id
            'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth':
            '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM'
        },
        'EDVR': {
            'EDVR_L_x4_SR_REDS_official-9f5f5039.pth':
            '127KXEjlCwfoPC1aXyDkluNwr9elwyHNb',
            'EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth':
            '1aVR3lkX6ItCphNLcT7F5bbbC484h4Qqy',
            'EDVR_M_woTSA_x4_SR_REDS_official-1edf645c.pth':
            '1C_WdN-NyNj-P7SOB5xIVuHl4EBOwd-Ny',
            'EDVR_M_x4_SR_REDS_official-32075921.pth':
            '1dd6aFj-5w2v08VJTq5mS9OFsD-wALYD6',
            'EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth':
            '1GZz_87ybR8eAAY3X2HWwI3L6ny7-5Yvl',
            'EDVR_L_deblur_REDS_official-ca46bd8c.pth':
            '1_ma2tgHscZtkIY2tEJkVdU-UP8bnqBRE',
            'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth':
            '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW'
        },
        'StyleGAN': {
            'stylegan2_ffhq_config_f_1024_official-3ab41b38.pth':
            '1qtdsT1FrvKQsFiW3OqOcIb-VS55TVy1g',
            'stylegan2_ffhq_config_f_1024_discriminator_official-a386354a.pth':
            '1nPqCxm8TkDU3IvXdHCzPUxlBwR5Pd78G',
            'stylegan2_cat_config_f_256_official-0a9173ad.pth':
            '1gfJkX6XO5pJ2J8LyMdvUgGldz7xwWpBJ',
            'stylegan2_cat_config_f_256_discriminator_official-2c97fd08.pth':
            '1hy5FEQQl28XvfqpiWvSBd8YnIzsyDRb7',
            'stylegan2_church_config_f_256_official-44ba63bf.pth':
            '1FCQMZXeOKZyl-xYKbl1Y_x2--rFl-1N_',
            'stylegan2_church_config_f_256_discriminator_official-20cd675b.pth':  # noqa: E501
            '1BS9ODHkUkhfTGFVfR6alCMGtr9nGm9ox',
            'stylegan2_car_config_f_512_official-e8fcab4f.pth':
            '14jS-nWNTguDSd1kTIX-tBHp2WdvK7hva',
            'stylegan2_car_config_f_512_discriminator_official-5008e3d1.pth':
            '1UxkAzZ0zvw4KzBVOUpShCivsdXBS8Zi2',
            'stylegan2_horse_config_f_256_official-26d57fee.pth':
            '12QsZ-mrO8_4gC0UrO15Jb3ykcQ88HxFx',
            'stylegan2_horse_config_f_256_discriminator_official-be6c4c33.pth':
            '1me4ybSib72xA9ZxmzKsHDtP-eNCKw_X4'
        },
        'EDSR': {
            'EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth':
            '1mREMGVDymId3NzIc2u90sl_X4-pb4ZcV',
            'EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth':
            '1EriqQqlIiRyPbrYGBbwr_FZzvb3iwqz5',
            'EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth':
            '1bCK6cFYU01uJudLgUUe-jgx-tZ3ikOWn',
            'EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth':
            '15257lZCRZ0V6F9LzTyZFYbbPrqNjKyMU',
            'EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth':
            '18q_D434sLG_rAZeHGonAX8dkqjoyZ2su',
            'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth':
            '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl'
        },
        'DUF': {
            'DUF_x2_16L_official-39537cb9.pth':
            '1e91cEZOlUUk35keK9EnuK0F54QegnUKo',
            'DUF_x3_16L_official-34ce53ec.pth':
            '1XN6aQj20esM7i0hxTbfiZr_SL8i4PZ76',
            'DUF_x4_16L_official-bf8f0cfa.pth':
            '1V_h9U1CZgLSHTv1ky2M3lvuH-hK5hw_J',
            'DUF_x4_28L_official-cbada450.pth':
            '1M8w0AMBJW65MYYD-_8_be0cSH_SHhDQ4',
            'DUF_x4_52L_official-483d2c78.pth':
            '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T'
        },
        'TOF': {
            'tof_x4_vimeo90k_official-32c9e01f.pth':
            '1TgQiXXsvkTBFrQ1D0eKPgL10tQGu0gKb'
        },
        'DFDNet': {
            'DFDNet_dict_512-f79685f0.pth':
            '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79',
            'DFDNet_official-d1fa5650.pth':
            '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe'
        },
        'dlib': {
            'mmod_human_face_detector-4cb19393.dat':
            '1FUM-hcoxNzFCOpCWbAUStBBMiU4uIGIL',
            'shape_predictor_5_face_landmarks-c4b1e980.dat':
            '1PNPSmFjmbuuUDd5Mg5LDxyk7tu7TQv2F',
            'shape_predictor_68_face_landmarks-fbdc2cb8.dat':
            '1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni'
        },
        'flownet': {
            'spynet_sintel_final-3d2a1287.pth':
            '1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF'
        },
        'BasicVSR': {
            'BasicVSR_REDS4-543c8261.pth':
            '1wLWdz18lWf9Z7lomHPkdySZ-_GV2920p',
            'BasicVSR_Vimeo90K_BDx4-e9bf46eb.pth':
            '1baaf4RSpzs_zcDAF_s2CyArrGvLgmXxW',
            'BasicVSR_Vimeo90K_BIx4-2a29695a.pth':
            '1ykIu1jv5wo95Kca2TjlieJFxeV4VVfHP',
            'EDVR_REDS_pretrained_for_IconVSR-f62a2f1e.pth':
            '1ShfwddugTmT3_kB8VL6KpCMrIpEO5sBi',
            'EDVR_Vimeo90K_pretrained_for_IconVSR-ee48ee92.pth':
            '16vR262NDVyVv5Q49xp2Sb-Llu05f63tt',
            'IconVSR_REDS-aaa5367f.pth':
            '1b8ir754uIAFUSJ8YW_cmPzqer19AR7Hz',
            'IconVSR_Vimeo90K_BDx4-cfcb7e00.pth':
            '13lp55s-YTd-fApx8tTy24bbHsNIGXdAH',
            'IconVSR_Vimeo90K_BIx4-35fec07c.pth':
            '1lWUB36ERjFbAspr-8UsopJ6xwOuWjh2g'
        }
    }

    # 'all' downloads every method; any other value must be a key of
    # `file_ids`, otherwise a KeyError is raised.
    if args.method == 'all':
        for method in file_ids.keys():
            download_pretrained_models(method, file_ids[method])
    else:
        download_pretrained_models(args.method, file_ids[args.method])
BasicSR/scripts/matlab_scripts/back_projection/backprojection.m
0 → 100644
View file @
e2696ece
% Iterative back-projection: refine HR estimate im_h so that its bicubic
% downsample matches the LR observation im_l.
%   im_h    - initial high-resolution image (refined and returned)
%   im_l    - low-resolution reference image
%   maxIter - number of back-projection iterations
function [im_h] = backprojection(im_h, im_l, maxIter)
[row_l, col_l,~] = size(im_l);
[row_h, col_h,~] = size(im_h);
% back-projection kernel: squared, renormalized 5x5 Gaussian (sigma = 1)
p = fspecial('gaussian', 5, 1);
p = p.^2;
p = p./sum(p(:));
im_l = double(im_l);
im_h = double(im_h);
for ii = 1:maxIter
% simulate the LR image from the current HR estimate
im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
% residual in LR space, upsampled back to HR size
im_diff = im_l - im_l_s;
im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
% add the filtered residual per channel (assumes a 3-channel image)
im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
end
BasicSR/scripts/matlab_scripts/back_projection/main_bp.m
0 → 100644
View file @
e2696ece
% Apply 20 back-projection iterations to every pre-computed SR result and
% save the refined images into a separate folder.
clear; close all; clc;

lr_dir = './LR';            % low-resolution inputs
pre_dir = './results';      % pre-computed SR outputs
out_dir = './results_20bp'; % refined outputs
n_iter = 20;

img_files = dir(fullfile(pre_dir, '*.png'));
if ~exist(out_dir, 'dir')
    mkdir(out_dir);
end

for k = 1:length(img_files)
    fprintf([num2str(k) '\n']);
    fname = img_files(k).name;
    img_lr = im2double(imread(fullfile(lr_dir, fname)));
    img_sr = im2double(imread(fullfile(pre_dir, fname)));
    img_sr = backprojection(img_sr, img_lr, n_iter);
    imwrite(img_sr, fullfile(out_dir, fname));
end
BasicSR/scripts/matlab_scripts/back_projection/main_reverse_filter.m
0 → 100644
View file @
e2696ece
% Reverse filtering: iteratively add the residual between the bicubic
% upsampled LR image and the bicubic down/up round-trip of the current
% result, then save the refined images.
clear; close all; clc;

lr_dir = './LR';            % low-resolution inputs
pre_dir = './results';      % pre-computed SR outputs
out_dir = './results_20if'; % refined outputs
n_iter = 20;

img_files = dir(fullfile(pre_dir, '*.png'));
if ~exist(out_dir, 'dir')
    mkdir(out_dir);
end

for k = 1:length(img_files)
    fprintf([num2str(k) '\n']);
    fname = img_files(k).name;
    img_lr = im2double(imread(fullfile(lr_dir, fname)));
    img_sr = im2double(imread(fullfile(pre_dir, fname)));
    ref_up = imresize(img_lr, 4, 'bicubic');
    for it = 1:n_iter
        img_sr = img_sr + (ref_up - imresize(imresize(img_sr, 1/4, 'bicubic'), 4, 'bicubic'));
    end
    imwrite(img_sr, fullfile(out_dir, fname));
end
BasicSR/scripts/matlab_scripts/generate_LR_Vimeo90K.m
0 → 100644
View file @
e2696ece
function generate_LR_Vimeo90K()
%% matlab code to generate bicubic-downsampled images for the Vimeo90K dataset
% Walks every frame of vimeo_septuplet, mod-crops it, downsamples x4 with
% bicubic interpolation, and mirrors the folder tree into
% vimeo_septuplet_matlabLRx4.
up_scale = 4;
mod_scale = 4;
idx = 0;
filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
for i = 1 : length(filepaths)
    [~, imname, ext] = fileparts(filepaths(i).name);
    folder_path = filepaths(i).folder;
    % mirror the source folder structure for the LR outputs
    save_LR_folder = strrep(folder_path, 'vimeo_septuplet', 'vimeo_septuplet_matlabLRx4');
    if ~exist(save_LR_folder, 'dir')
        mkdir(save_LR_folder);
    end
    % fileparts('.') yields an empty name; fileparts('..') yields name '.'
    if isempty(imname)
        disp('Ignore . folder.');
    elseif strcmp(imname, '.')
        disp('Ignore .. folder.');
    else
        idx = idx + 1;
        str_result = sprintf('%d\t%s.\n', idx, imname);
        fprintf(str_result);
        % read image
        img = imread(fullfile(folder_path, [imname, ext]));
        img = im2double(img);
        % modcrop so dimensions divide evenly by the scale
        img = modcrop(img, mod_scale);
        % LR: bicubic x4 downsample
        im_LR = imresize(img, 1 / up_scale, 'bicubic');
        if exist('save_LR_folder', 'var')
            imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
        end
    end
end
end
%% modcrop
% Crop img at the bottom/right so height and width are multiples of modulo.
function img = modcrop(img, modulo)
if size(img, 3) == 1
    % grayscale: size() returns [h w] directly
    sz = size(img);
    sz = sz - mod(sz, modulo);
    img = img(1:sz(1), 1:sz(2));
else
    % multi-channel: crop only the first two (spatial) dimensions
    tmpsz = size(img);
    sz = tmpsz(1:2);
    sz = sz - mod(sz, modulo);
    img = img(1:sz(1), 1:sz(2),:);
end
end
BasicSR/scripts/matlab_scripts/generate_bicubic_img.m
0 → 100644
View file @
e2696ece
function generate_bicubic_img()
%% matlab code to generate mod images, bicubic-downsampled images and
%% bicubic_upsampled images
% Each of the three outputs is optional: comment out the corresponding
% save_*_folder assignment and that output is skipped (the exist(...,'var')
% checks below guard every stage).

%% set configurations
% comment the unnecessary lines
input_folder = '../../datasets/Set5/original';
save_mod_folder = '../../datasets/Set5/GTmod12';
save_lr_folder = '../../datasets/Set5/LRbicx2';
% save_bic_folder = '';

mod_scale = 12;
up_scale = 2;

% create output folders (warn when they already exist and will be overwritten)
if exist('save_mod_folder', 'var')
    if exist(save_mod_folder, 'dir')
        disp(['It will cover ', save_mod_folder]);
    else
        mkdir(save_mod_folder);
    end
end
if exist('save_lr_folder', 'var')
    if exist(save_lr_folder, 'dir')
        disp(['It will cover ', save_lr_folder]);
    else
        mkdir(save_lr_folder);
    end
end
if exist('save_bic_folder', 'var')
    if exist(save_bic_folder, 'dir')
        disp(['It will cover ', save_bic_folder]);
    else
        mkdir(save_bic_folder);
    end
end

idx = 0;
filepaths = dir(fullfile(input_folder, '*.*'));
for i = 1 : length(filepaths)
    % 'paths' is unused; fileparts is only needed for name/extension
    [paths, img_name, ext] = fileparts(filepaths(i).name);
    % fileparts('.') yields an empty name; fileparts('..') yields name '.'
    if isempty(img_name)
        disp('Ignore . folder.');
    elseif strcmp(img_name, '.')
        disp('Ignore .. folder.');
    else
        idx = idx + 1;
        str_result = sprintf('%d\t%s.\n', idx, img_name);
        fprintf(str_result);
        % read image
        img = imread(fullfile(input_folder, [img_name, ext]));
        img = im2double(img);
        % modcrop so dimensions divide evenly by mod_scale
        img = modcrop(img, mod_scale);
        if exist('save_mod_folder', 'var')
            imwrite(img, fullfile(save_mod_folder, [img_name, '.png']));
        end
        % LR: bicubic downsample
        im_lr = imresize(img, 1 / up_scale, 'bicubic');
        if exist('save_lr_folder', 'var')
            imwrite(im_lr, fullfile(save_lr_folder, [img_name, '.png']));
        end
        % Bicubic: upsample the LR image back to the original size
        if exist('save_bic_folder', 'var')
            im_bicubic = imresize(im_lr, up_scale, 'bicubic');
            imwrite(im_bicubic, fullfile(save_bic_folder, [img_name, '.png']));
        end
    end
end
end
%% modcrop
% Trim the bottom/right edges of img so both spatial dimensions are
% exact multiples of modulo.
function img = modcrop(img, modulo)
dims = size(img);
new_hw = dims(1:2) - mod(dims(1:2), modulo);
if size(img, 3) == 1
    img = img(1:new_hw(1), 1:new_hw(2));
else
    img = img(1:new_hw(1), 1:new_hw(2), :);
end
end
BasicSR/scripts/metrics/calculate_fid_folder.py
0 → 100644
View file @
e2696ece
import
argparse
import
math
import
numpy
as
np
import
torch
from
torch.utils.data
import
DataLoader
from
basicsr.data
import
build_dataset
from
basicsr.metrics.fid
import
calculate_fid
,
extract_inception_features
,
load_patched_inception_v3
def calculate_fid_folder():
    """Compute the FID of an image folder against precomputed dataset statistics.

    Command-line driven: reads images from ``folder`` through a BasicSR
    SingleImageDataset, extracts Inception features, computes their
    mean/covariance, and compares them with the stats loaded from
    ``--fid_stats`` (a torch file holding 'mean' and 'cov').
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()
    parser.add_argument('folder', type=str, help='Path to the folder.')
    parser.add_argument('--fid_stats', type=str, help='Path to the dataset fid statistics.')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_sample', type=int, default=50000)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--backend', type=str, default='disk', help='io backend for dataset. Option: disk, lmdb')
    args = parser.parse_args()

    # inception model
    inception = load_patched_inception_v3(device)

    # create dataset
    # Normalization to mean/std 0.5 maps images to [-1, 1] for Inception.
    opt = {}
    opt['name'] = 'SingleImageDataset'
    opt['type'] = 'SingleImageDataset'
    opt['dataroot_lq'] = args.folder
    opt['io_backend'] = dict(type=args.backend)
    opt['mean'] = [0.5, 0.5, 0.5]
    opt['std'] = [0.5, 0.5, 0.5]
    dataset = build_dataset(opt)

    # create dataloader
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=None,
        drop_last=False)
    # never request more samples than the folder actually contains
    args.num_sample = min(args.num_sample, len(dataset))
    total_batch = math.ceil(args.num_sample / args.batch_size)

    def data_generator(data_loader, total_batch):
        # yield at most total_batch batches of 'lq' images
        for idx, data in enumerate(data_loader):
            if idx >= total_batch:
                break
            else:
                yield data['lq']

    features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
    features = features.numpy()
    total_len = features.shape[0]
    # the last batch may overshoot num_sample; truncate to the exact count
    features = features[:args.num_sample]
    print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # load the dataset stats
    stats = torch.load(args.fid_stats)
    real_mean = stats['mean']
    real_cov = stats['cov']

    # calculate FID metric
    fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
    print('fid:', fid)
# CLI entry point.
if __name__ == '__main__':
    calculate_fid_folder()
BasicSR/scripts/metrics/calculate_fid_stats_from_datasets.py
0 → 100644
View file @
e2696ece
import
argparse
import
math
import
numpy
as
np
import
torch
from
torch.utils.data
import
DataLoader
from
basicsr.data
import
build_dataset
from
basicsr.metrics.fid
import
extract_inception_features
,
load_patched_inception_v3
def calculate_stats_from_dataset():
    """Extract Inception features from the FFHQ dataset and save FID stats.

    Saves a torch file ``inception_FFHQ_{size}.pth`` containing the feature
    mean and covariance, for later use as the 'real' side of an FID
    computation.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_sample', type=int, default=50000)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--size', type=int, default=512)
    # NOTE(review): args.dataroot is parsed but never used below;
    # dataroot_gt is hard-coded from args.size instead.
    parser.add_argument('--dataroot', type=str, default='datasets/ffhq')
    args = parser.parse_args()

    # inception model
    inception = load_patched_inception_v3(device)

    # create dataset
    # Normalization to mean/std 0.5 maps images to [-1, 1] for Inception.
    opt = {}
    opt['name'] = 'FFHQ'
    opt['type'] = 'FFHQDataset'
    opt['dataroot_gt'] = f'datasets/ffhq/ffhq_{args.size}.lmdb'
    opt['io_backend'] = dict(type='lmdb')
    opt['use_hflip'] = False
    opt['mean'] = [0.5, 0.5, 0.5]
    opt['std'] = [0.5, 0.5, 0.5]
    dataset = build_dataset(opt)

    # create dataloader
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        sampler=None,
        drop_last=False)
    total_batch = math.ceil(args.num_sample / args.batch_size)

    def data_generator(data_loader, total_batch):
        # yield at most total_batch batches of 'gt' images
        for idx, data in enumerate(data_loader):
            if idx >= total_batch:
                break
            else:
                yield data['gt']

    features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
    features = features.numpy()
    total_len = features.shape[0]
    # the last batch may overshoot num_sample; truncate to the exact count
    features = features[:args.num_sample]
    print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')

    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    # legacy (non-zipfile) serialization keeps the stats loadable by older torch
    save_path = f'inception_{opt["name"]}_{args.size}.pth'
    torch.save(dict(name=opt['name'], size=args.size, mean=mean, cov=cov), save_path, _use_new_zipfile_serialization=False)
# CLI entry point.
if __name__ == '__main__':
    calculate_stats_from_dataset()
BasicSR/scripts/metrics/calculate_lpips.py
0 → 100644
View file @
e2696ece
import
cv2
import
glob
import
numpy
as
np
import
os.path
as
osp
from
torchvision.transforms.functional
import
normalize
from
basicsr.utils
import
img2tensor
# lpips is an optional third-party dependency; print an install hint
# instead of crashing at import time when it is missing.
try:
    import lpips
except ImportError:
    print('Please install lpips: pip install lpips')
def main():
    """Compute per-image and average LPIPS between GT and restored folders.

    Restored files are matched to GT files by basename (plus optional
    suffix). Requires a CUDA device and the `lpips` package.
    """
    # Configurations
    # -------------------------------------------------------------------------
    folder_gt = 'datasets/celeba/celeba_512_validation'
    folder_restored = 'datasets/celeba/celeba_512_validation_lq'
    # crop_border = 4
    suffix = ''
    # -------------------------------------------------------------------------
    loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()  # RGB, normalized to [-1,1]
    lpips_all = []
    img_list = sorted(glob.glob(osp.join(folder_gt, '*')))

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    for i, img_path in enumerate(img_list):
        basename, ext = osp.splitext(osp.basename(img_path))
        # images read as BGR uint8, scaled to [0, 1] float32
        img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        img_restored = cv2.imread(osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

        img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True)
        # norm to [-1, 1]
        normalize(img_gt, mean, std, inplace=True)
        normalize(img_restored, mean, std, inplace=True)

        # calculate lpips
        lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).cuda(), img_gt.unsqueeze(0).cuda())

        print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.')
        lpips_all.append(lpips_val)

    print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}')
# CLI entry point.
if __name__ == '__main__':
    main()
BasicSR/scripts/metrics/calculate_niqe.py
0 → 100644
View file @
e2696ece
import
argparse
import
cv2
import
os
import
warnings
from
basicsr.metrics
import
calculate_niqe
from
basicsr.utils
import
scandir
def main(args):
    """Compute per-image and average NIQE scores for a folder of images.

    Args:
        args: parsed CLI namespace with ``input`` (folder path) and
            ``crop_border`` (pixels cropped from each side before scoring).
    """
    niqe_all = []
    img_list = sorted(scandir(args.input, recursive=True, full_path=True))

    for i, img_path in enumerate(img_list):
        basename, _ = os.path.splitext(os.path.basename(img_path))
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)

        # NIQE internals can emit RuntimeWarnings (e.g. on flat patches);
        # silence them for cleaner per-image output.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=RuntimeWarning)
            niqe_score = calculate_niqe(img, args.crop_border, input_order='HWC', convert_to='y')
        print(f'{i+1:3d}: {basename:25}. \tNIQE: {niqe_score:.6f}')
        niqe_all.append(niqe_score)

    print(args.input)
    print(f'Average: NIQE: {sum(niqe_all) / len(niqe_all):.6f}')
# CLI entry point: parse the input folder and crop border, then score it.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default='datasets/val_set14/Set14', help='Input path')
    parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
    args = parser.parse_args()
    main(args)
Prev
1
…
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment