ModelZoo / ControlNet_pytorch / Commits / e2696ece

Commit e2696ece, authored Nov 22, 2023 by mashun1

    controlnet

Pipeline #643 canceled with stages · Changes: 263 · Pipelines: 1
Showing 3 changed files with 223 additions and 0 deletions (+223 -0):

  BasicSR/scripts/metrics/calculate_psnr_ssim.py        +77  -0
  BasicSR/scripts/metrics/calculate_stylegan2_fid.py    +72  -0
  BasicSR/scripts/model_conversion/convert_dfdnet.py    +74  -0
BasicSR/scripts/metrics/calculate_psnr_ssim.py    (new file, mode 0 → 100644)

import argparse
import cv2
import numpy as np
from os import path as osp

from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.utils import bgr2ycbcr, scandir


def main(args):
    """Calculate PSNR and SSIM for images.
    """
    psnr_all = []
    ssim_all = []
    img_list_gt = sorted(list(scandir(args.gt, recursive=True, full_path=True)))
    img_list_restored = sorted(list(scandir(args.restored, recursive=True, full_path=True)))

    if args.test_y_channel:
        print('Testing Y channel.')
    else:
        print('Testing RGB channels.')

    for i, img_path in enumerate(img_list_gt):
        basename, ext = osp.splitext(osp.basename(img_path))
        img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        if args.suffix == '':
            img_path_restored = img_list_restored[i]
        else:
            img_path_restored = osp.join(args.restored, basename + args.suffix + ext)
        img_restored = cv2.imread(img_path_restored, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

        if args.correct_mean_var:
            mean_l = []
            std_l = []
            for j in range(3):
                mean_l.append(np.mean(img_gt[:, :, j]))
                std_l.append(np.std(img_gt[:, :, j]))
            for j in range(3):
                # correct twice
                mean = np.mean(img_restored[:, :, j])
                img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
                std = np.std(img_restored[:, :, j])
                img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]

                mean = np.mean(img_restored[:, :, j])
                img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
                std = np.std(img_restored[:, :, j])
                img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]

        if args.test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
            img_gt = bgr2ycbcr(img_gt, y_only=True)
            img_restored = bgr2ycbcr(img_restored, y_only=True)

        # calculate PSNR and SSIM
        psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
        ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
        print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
        psnr_all.append(psnr)
        ssim_all.append(ssim)
    print(args.gt)
    print(args.restored)
    print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, SSIM: {sum(ssim_all) / len(ssim_all):.6f}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gt', type=str, default='datasets/val_set14/Set14', help='Path to gt (Ground-Truth)')
    parser.add_argument('--restored', type=str, default='results/Set14', help='Path to restored images')
    parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
    parser.add_argument('--suffix', type=str, default='', help='Suffix for restored images')
    parser.add_argument(
        '--test_y_channel',
        action='store_true',
        help='If True, test Y channel (In MatLab YCbCr format). If False, test RGB channels.')
    parser.add_argument('--correct_mean_var', action='store_true', help='Correct the mean and var of restored images.')
    args = parser.parse_args()
    main(args)
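For reference, the same metrics can be computed for a single image pair by calling the basicsr functions the script imports, mirroring its preprocessing (BGR images read with OpenCV, scaled to [0, 1], then passed back in the [0, 255] range). This is a minimal sketch; the image paths are hypothetical placeholders, not files from this commit.

import cv2
import numpy as np

from basicsr.metrics import calculate_psnr, calculate_ssim

# Hypothetical ground-truth / restored image pair (BGR, read as-is by OpenCV).
img_gt = cv2.imread('datasets/val_set14/Set14/baboon.png', cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
img_restored = cv2.imread('results/Set14/baboon.png', cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

# Same call pattern as the script: images rescaled to [0, 255], HWC layout, no border crop.
psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=0, input_order='HWC')
ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=0, input_order='HWC')
print(f'PSNR: {psnr:.6f} dB, SSIM: {ssim:.6f}')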
BasicSR/scripts/metrics/calculate_stylegan2_fid.py    (new file, mode 0 → 100644)

import argparse
import math
import numpy as np
import torch
from torch import nn

from basicsr.archs.stylegan2_arch import StyleGAN2Generator
from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3


def calculate_stylegan2_fid():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()
    parser.add_argument('ckpt', type=str, help='Path to the stylegan2 checkpoint.')
    parser.add_argument('fid_stats', type=str, help='Path to the dataset fid statistics.')
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_sample', type=int, default=50000)
    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    args = parser.parse_args()

    # create stylegan2 model
    generator = StyleGAN2Generator(
        out_size=args.size,
        num_style_feat=512,
        num_mlp=8,
        channel_multiplier=args.channel_multiplier,
        resample_kernel=(1, 3, 3, 1))
    generator.load_state_dict(torch.load(args.ckpt)['params_ema'])
    generator = nn.DataParallel(generator).eval().to(device)

    if args.truncation < 1:
        with torch.no_grad():
            truncation_latent = generator.mean_latent(args.truncation_mean)
    else:
        truncation_latent = None

    # inception model
    inception = load_patched_inception_v3(device)

    total_batch = math.ceil(args.num_sample / args.batch_size)

    def sample_generator(total_batch):
        for _ in range(total_batch):
            with torch.no_grad():
                latent = torch.randn(args.batch_size, 512, device=device)
                samples, _ = generator([latent], truncation=args.truncation, truncation_latent=truncation_latent)
            yield samples

    features = extract_inception_features(sample_generator(total_batch), inception, total_batch, device)
    features = features.numpy()
    total_len = features.shape[0]
    features = features[:args.num_sample]
    print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # load the dataset stats
    stats = torch.load(args.fid_stats)
    real_mean = stats['mean']
    real_cov = stats['cov']

    # calculate FID metric
    fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
    print('fid:', fid)


if __name__ == '__main__':
    calculate_stylegan2_fid()
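The script reduces the generated samples to Inception feature statistics (sample_mean, sample_cov) and compares them against precomputed dataset statistics via calculate_fid. As background, the sketch below spells out the Fréchet distance that comparison amounts to; it is an illustration of the standard FID formula, not basicsr's exact implementation, and the scipy dependency is an assumption of this sketch.

import numpy as np
from scipy import linalg  # assumption: scipy is available in the environment

def frechet_distance(mu1, cov1, mu2, cov2, eps=1e-6):
    """FID = ||mu1 - mu2||^2 + Tr(cov1 + cov2 - 2 * sqrt(cov1 @ cov2))."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(cov1 @ cov2, disp=False)
    if not np.isfinite(covmean).all():
        # Regularize a near-singular covariance product before retrying the square root.
        offset = np.eye(cov1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((cov1 + offset) @ (cov2 + offset), disp=False)
    # Numerical error can leave a tiny imaginary component; keep the real part.
    covmean = covmean.real if np.iscomplexobj(covmean) else covmean
    return float(diff @ diff + np.trace(cov1) + np.trace(cov2) - 2 * np.trace(covmean))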
BasicSR/scripts/model_conversion/convert_dfdnet.py    (new file, mode 0 → 100644)

import torch

from basicsr.archs.dfdnet_arch import DFDNet
from basicsr.archs.vgg_arch import NAMES


def convert_net(ori_net, crt_net):
    for crt_k, _ in crt_net.items():
        # vgg feature extractor
        if 'vgg_extractor' in crt_k:
            ori_k = crt_k.replace('vgg_extractor', 'VggExtract').replace('vgg_net', 'model')
            if 'mean' in crt_k:
                ori_k = ori_k.replace('mean', 'RGB_mean')
            elif 'std' in crt_k:
                ori_k = ori_k.replace('std', 'RGB_std')
            else:
                idx = NAMES['vgg19'].index(crt_k.split('.')[2])
                if 'weight' in crt_k:
                    ori_k = f'VggExtract.model.features.{idx}.weight'
                else:
                    ori_k = f'VggExtract.model.features.{idx}.bias'
        elif 'attn_blocks' in crt_k:
            if 'left_eye' in crt_k:
                ori_k = crt_k.replace('attn_blocks.left_eye', 'le')
            elif 'right_eye' in crt_k:
                ori_k = crt_k.replace('attn_blocks.right_eye', 're')
            elif 'mouth' in crt_k:
                ori_k = crt_k.replace('attn_blocks.mouth', 'mo')
            elif 'nose' in crt_k:
                ori_k = crt_k.replace('attn_blocks.nose', 'no')
            else:
                raise ValueError('Wrong!')
        elif 'multi_scale_dilation' in crt_k:
            if 'conv_blocks' in crt_k:
                _, _, c, d, e = crt_k.split('.')
                ori_k = f'MSDilate.conv{int(c) + 1}.{d}.{e}'
            else:
                ori_k = crt_k.replace('multi_scale_dilation.conv_fusion', 'MSDilate.convi')
        elif crt_k.startswith('upsample'):
            ori_k = crt_k.replace('upsample', 'up')
            if 'scale_block' in crt_k:
                ori_k = ori_k.replace('scale_block', 'ScaleModel1')
            elif 'shift_block' in crt_k:
                ori_k = ori_k.replace('shift_block', 'ShiftModel1')
            elif 'upsample4' in crt_k and 'body' in crt_k:
                ori_k = ori_k.replace('body', 'Model')
        else:
            print('unprocess key: ', crt_k)

        # replace
        if crt_net[crt_k].size() != ori_net[ori_k].size():
            raise ValueError('Wrong tensor size: \n'
                             f'crt_net: {crt_net[crt_k].size()}\n'
                             f'ori_net: {ori_net[ori_k].size()}')
        else:
            crt_net[crt_k] = ori_net[ori_k]

    return crt_net


if __name__ == '__main__':
    ori_net = torch.load('experiments/pretrained_models/DFDNet/DFDNet_official_original.pth')
    dfd_net = DFDNet(64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth')
    crt_net = dfd_net.state_dict()

    crt_net_params = convert_net(ori_net, crt_net)
    torch.save(
        dict(params=crt_net_params),
        'experiments/pretrained_models/DFDNet/DFDNet_official.pth',
        _use_new_zipfile_serialization=False)
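A quick way to sanity-check the conversion is to load the saved 'params' dict back into a freshly constructed DFDNet. The sketch below is a hypothetical check that reuses the paths from the script above; it is not part of the original commit.

import torch

from basicsr.archs.dfdnet_arch import DFDNet

# Rebuild the network exactly as the conversion script does, then load the converted weights.
net = DFDNet(64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth')
state = torch.load('experiments/pretrained_models/DFDNet/DFDNet_official.pth', map_location='cpu')
net.load_state_dict(state['params'], strict=True)  # raises if any key or tensor shape mismatches
print('Converted checkpoint loads cleanly.')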