Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
DiffBIR_pytorch
Commits
e5a6b0a4
Commit
e5a6b0a4
authored
Sep 15, 2023
by
zycXD
Browse files
fix bugs in face enhancement
parent
a27df00b
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
572 additions
and
42 deletions
+572
-42
README.md
README.md
+7
-14
assets/visual_results/whole_image1.png
assets/visual_results/whole_image1.png
+0
-0
assets/visual_results/whole_image2.png
assets/visual_results/whole_image2.png
+0
-0
assets/visual_results/whole_image3.png
assets/visual_results/whole_image3.png
+0
-0
inference_face.py
inference_face.py
+40
-27
utils/face_restoration_helper.py
utils/face_restoration_helper.py
+517
-0
utils/realesrgan/realesrganer.py
utils/realesrgan/realesrganer.py
+8
-1
No files found.
README.md
View file @
e5a6b0a4
...
@@ -49,12 +49,15 @@
...
@@ -49,12 +49,15 @@
[
<img src="assets/visual_results/face1.png" height="223px"/>
](
https://imgsli.com/MTk5ODI5
)
[
<img src="assets/visual_results/face2.png" height="223px"/>
]
(https://imgsli.com/MTk5ODMw)
[
<img src="assets/visual_results/face4.png" height="223px"/>
](
https://imgsli.com/MTk5ODM0
)
[
<img src="assets/visual_results/face1.png" height="223px"/>
](
https://imgsli.com/MTk5ODI5
)
[
<img src="assets/visual_results/face2.png" height="223px"/>
]
(https://imgsli.com/MTk5ODMw)
[
<img src="assets/visual_results/face4.png" height="223px"/>
](
https://imgsli.com/MTk5ODM0
)
[
<img src="assets/visual_results/whole_image2.png" height="268"/>
](
https://imgsli.com/MjA1OTU3
)
[
<img src="assets/visual_results/whole_image3.png" height="268"/>
]
(https://imgsli.com/MjA1OTY2)
[
<img src="assets/visual_results/whole_image1.png" height="370"/>
](
https://imgsli.com/MjA2MTU0
)
[
<img src="assets/visual_results/whole_image2.png" height="370"/>
](
https://imgsli.com/MjA2MTQ4
)
<!-- [<img src="assets/visual_results/whole_image3.png" height="268"/>
](https://imgsli.com/MjA1OTY2) -->
<!-- [<img src="assets/visual_results/face3.png" height="223px"/>
](https://imgsli.com/MTk5ODMy) -->
<!-- [<img src="assets/visual_results/face3.png" height="223px"/>
](https://imgsli.com/MTk5ODMy) -->
<!-- [<img src="assets/visual_results/face5.png" height="223px"/>
](https://imgsli.com/MTk5ODM1) -->
<!-- [<img src="assets/visual_results/face5.png" height="223px"/>
](https://imgsli.com/MTk5ODM1) -->
[
<img src="assets/visual_results/whole_image1.png" height="410"/>
](
https://imgsli.com/MjA1OTU5
)
<!--
[<img src="assets/visual_results/whole_image1.png" height="410"/>
](https://imgsli.com/MjA1OTU5)
-->
:star: Face and the background enhanced by DiffBIR.
<!-- </details>
-->
<!-- </details>
-->
...
@@ -171,14 +174,9 @@ Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/
...
@@ -171,14 +174,9 @@ Download [face_full_v1.ckpt](https://huggingface.co/lxq007/DiffBIR/resolve/main/
```
shell
```
shell
# for aligned face inputs
# for aligned face inputs
python inference_face.py
\
python inference_face.py
\
--config
configs/model/cldm.yaml
\
--ckpt
weights/face_full_v1.ckpt
\
--input
inputs/demo/face/aligned
\
--input
inputs/demo/face/aligned
\
--steps
50
\
--sr_scale
1
\
--sr_scale
1
\
--image_size
512
\
--output
results/demo/face/aligned
\
--color_fix_type
wavelet
\
--output
results/demo/face/aligned
--resize_back
\
--has_aligned
\
--has_aligned
\
--device
cuda
--device
cuda
```
```
...
@@ -188,14 +186,9 @@ python inference_face.py \
...
@@ -188,14 +186,9 @@ python inference_face.py \
```
shell
```
shell
# for unaligned face inputs
# for unaligned face inputs
python inference_face.py
\
python inference_face.py
\
--config
configs/model/cldm.yaml
\
--ckpt
weights/face_full_v1.ckpt
\
--input
inputs/demo/face/whole_img
\
--input
inputs/demo/face/whole_img
\
--steps
50
\
--sr_scale
2
\
--sr_scale
2
\
--image_size
512
\
--output
results/demo/face/whole_img
\
--color_fix_type
wavelet
\
--output
results/demo/face/whole_img
--resize_back
\
--bg_upsampler
DiffBIR
\
--bg_upsampler
DiffBIR
\
--device
cuda
--device
cuda
```
```
...
...
assets/visual_results/whole_image1.png
View replaced file @
a27df00b
View file @
e5a6b0a4
3.95 MB
|
W:
|
H:
2.54 MB
|
W:
|
H:
2-up
Swipe
Onion skin
assets/visual_results/whole_image2.png
View replaced file @
a27df00b
View file @
e5a6b0a4
2.57 MB
|
W:
|
H:
2.34 MB
|
W:
|
H:
2-up
Swipe
Onion skin
assets/visual_results/whole_image3.png
deleted
100644 → 0
View file @
a27df00b
1.99 MB
inference_face.py
View file @
e5a6b0a4
...
@@ -7,32 +7,41 @@ from omegaconf import OmegaConf
...
@@ -7,32 +7,41 @@ from omegaconf import OmegaConf
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
from
argparse
import
ArgumentParser
,
Namespace
from
argparse
import
ArgumentParser
,
Namespace
from
facexlib.utils.face_restoration_helper
import
FaceRestoreHelper
from
ldm.xformers_state
import
auto_xformers_status
from
ldm.xformers_state
import
auto_xformers_status
from
model.cldm
import
ControlLDM
from
model.cldm
import
ControlLDM
from
utils.common
import
instantiate_from_config
,
load_state_dict
from
utils.common
import
instantiate_from_config
,
load_state_dict
from
utils.file
import
list_image_files
,
get_file_name_parts
from
utils.file
import
list_image_files
,
get_file_name_parts
from
utils.image
import
auto_resize
,
pad
from
utils.image
import
auto_resize
,
pad
from
utils.file
import
load_file_from_url
from
utils.file
import
load_file_from_url
from
utils.face_restoration_helper
import
FaceRestoreHelper
from
inference
import
process
from
inference
import
process
pretrained_models
=
{
'general_v1'
:
{
'ckpt_url'
:
'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt'
,
'swinir_url'
:
'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt'
},
'face_v1'
:
{
'ckpt_url'
:
'https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt'
}
}
def
parse_args
()
->
Namespace
:
def
parse_args
()
->
Namespace
:
parser
=
ArgumentParser
()
parser
=
ArgumentParser
()
# model
# model
# Specify the model ckpt path, and the official model can be downloaded direclty.
# Specify the model ckpt path, and the official model can be downloaded direclty.
parser
.
add_argument
(
"--ckpt"
,
type
=
str
,
help
=
'Model checkpoint.'
,
default
=
'weights/face_full_v1.ckpt'
)
parser
.
add_argument
(
"--ckpt"
,
type
=
str
,
help
=
'Model checkpoint.'
,
default
=
'weights/face_full_v1.ckpt'
)
parser
.
add_argument
(
"--config"
,
required
=
True
,
type
=
str
,
help
=
'Model config file.'
)
parser
.
add_argument
(
"--config"
,
type
=
str
,
default
=
'configs/model/cldm.yaml'
,
help
=
'Model config file.'
)
parser
.
add_argument
(
"--reload_swinir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--reload_swinir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--swinir_ckpt"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--swinir_ckpt"
,
type
=
str
,
default
=
None
)
# input and preprocessing
# input and preprocessing
parser
.
add_argument
(
"--input"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--input"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--steps"
,
required
=
True
,
type
=
int
)
parser
.
add_argument
(
"--steps"
,
type
=
int
,
default
=
50
)
parser
.
add_argument
(
"--sr_scale"
,
type
=
float
,
default
=
2
)
parser
.
add_argument
(
"--sr_scale"
,
type
=
float
,
default
=
2
,
help
=
'An upscale factor.'
)
parser
.
add_argument
(
"--image_size"
,
type
=
int
,
default
=
512
)
parser
.
add_argument
(
"--image_size"
,
type
=
int
,
default
=
512
,
help
=
'Image size as the model input.'
)
parser
.
add_argument
(
"--repeat_times"
,
type
=
int
,
default
=
1
,
help
=
'To generate multiple results for each input image.'
)
parser
.
add_argument
(
"--repeat_times"
,
type
=
int
,
default
=
1
,
help
=
'To generate multiple results for each input image.'
)
parser
.
add_argument
(
"--disable_preprocess_model"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--disable_preprocess_model"
,
action
=
"store_true"
)
...
@@ -42,19 +51,20 @@ def parse_args() -> Namespace:
...
@@ -42,19 +51,20 @@ def parse_args() -> Namespace:
parser
.
add_argument
(
'--detection_model'
,
type
=
str
,
default
=
'retinaface_resnet50'
,
parser
.
add_argument
(
'--detection_model'
,
type
=
str
,
default
=
'retinaface_resnet50'
,
help
=
'Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib.
\
help
=
'Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib.
\
Default: retinaface_resnet50'
)
Default: retinaface_resnet50'
)
# TODO: support diffbir background upsampler
# Loading two DiffBIR models requires huge GPU memory capacity. Choose RealESRGAN as an alternative.
# Loading two DiffBIR models requires huge GPU memory capacity. Choose RealESRGAN as an alternative.
parser
.
add_argument
(
'--bg_upsampler'
,
type
=
str
,
default
=
'RealESRGAN'
,
choices
=
[
'DiffBIR'
,
'RealESRGAN'
],
help
=
'Background upsampler.'
)
parser
.
add_argument
(
'--bg_upsampler'
,
type
=
str
,
default
=
'RealESRGAN'
,
choices
=
[
'DiffBIR'
,
'RealESRGAN'
],
help
=
'Background upsampler.'
)
parser
.
add_argument
(
'--bg_tile'
,
type
=
int
,
default
=
400
,
help
=
'Tile size for background sampler.'
)
parser
.
add_argument
(
'--bg_tile'
,
type
=
int
,
default
=
400
,
help
=
'Tile size for background sampler.'
)
# postprocessing and saving
# postprocessing and saving
parser
.
add_argument
(
"--color_fix_type"
,
type
=
str
,
default
=
"wavelet"
,
choices
=
[
"wavelet"
,
"adain"
,
"none"
])
parser
.
add_argument
(
"--color_fix_type"
,
type
=
str
,
default
=
"wavelet"
,
choices
=
[
"wavelet"
,
"adain"
,
"none"
])
parser
.
add_argument
(
"--resize_back"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--output"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--output"
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
"--show_lq"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--show_lq"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--skip_if_exist"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--skip_if_exist"
,
action
=
"store_true"
)
# change seed to finte-tune your restored images! just specify another random number.
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
231
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
231
)
# TODO: support mps device for MacOS devices
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cuda"
,
choices
=
[
"cpu"
,
"cuda"
])
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cuda"
,
choices
=
[
"cpu"
,
"cuda"
])
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -62,19 +72,21 @@ def parse_args() -> Namespace:
...
@@ -62,19 +72,21 @@ def parse_args() -> Namespace:
def
build_diffbir_model
(
model_config
,
ckpt
,
swinir_ckpt
=
None
):
def
build_diffbir_model
(
model_config
,
ckpt
,
swinir_ckpt
=
None
):
''''
''''
model_config: model architecture config file.
model_config: model architecture config file.
ckpt: path of the model checkpoint file.
ckpt: checkpoint file path of the main model.
swinir_ckpt: checkpoint file path of the swinir model.
load swinir from the main model if set None.
'''
'''
weight_root
=
os
.
path
.
dirname
(
ckpt
)
weight_root
=
os
.
path
.
dirname
(
ckpt
)
# download ckpt automatically if ckpt not exist in the local path
# download ckpt automatically if ckpt not exist in the local path
if
'general_full_v1'
in
ckpt
:
if
'general_full_v1'
in
ckpt
:
ckpt_url
=
'https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt'
ckpt_url
=
pretrained_models
[
'general_v1'
][
'ckpt_url'
]
if
swinir_ckpt
is
None
:
if
swinir_ckpt
is
None
:
swinir_ckpt
=
f
'
{
weight_root
}
/general_swinir_v1.ckpt'
swinir_ckpt
=
f
'
{
weight_root
}
/general_swinir_v1.ckpt'
swinir_url
=
'https://huggingface.co/lxq007/DiffBIR/resolve/main/
general_swinir_
v1.ckpt'
swinir_url
=
pretrained_models
[
'
general_
v1'
][
'
swinir_
url'
]
elif
'face_full_v1'
in
ckpt
:
elif
'face_full_v1'
in
ckpt
:
# swinir ckpt is already included in
face_full_v1.ckpt
# swinir ckpt is already included in
the main model
ckpt_url
=
'https://huggingface.co/lxq007/DiffBIR/resolve/main/face_full_v1.ckpt'
ckpt_url
=
pretrained_models
[
'face_v1'
][
'ckpt_url'
]
else
:
else
:
# define a custom diffbir model
# define a custom diffbir model
raise
NotImplementedError
(
'undefined diffbir model type!'
)
raise
NotImplementedError
(
'undefined diffbir model type!'
)
...
@@ -116,8 +128,7 @@ def main() -> None:
...
@@ -116,8 +128,7 @@ def main() -> None:
)
)
# set up the backgrouns upsampler
# set up the backgrouns upsampler
if
args
.
bg_upsampler
.
lower
()
==
'diffbir'
:
if
args
.
bg_upsampler
==
'DiffBIR'
:
# TODO: to support DiffBIR as background upsampler
# Loading two DiffBIR models consumes huge GPU memory capacity.
# Loading two DiffBIR models consumes huge GPU memory capacity.
bg_upsampler
=
build_diffbir_model
(
args
.
config
,
'weights/general_full_v1.pth'
)
bg_upsampler
=
build_diffbir_model
(
args
.
config
,
'weights/general_full_v1.pth'
)
# try:
# try:
...
@@ -125,7 +136,7 @@ def main() -> None:
...
@@ -125,7 +136,7 @@ def main() -> None:
# except:
# except:
# # put the bg_upsampler on cpu to avoid OOM
# # put the bg_upsampler on cpu to avoid OOM
# gpu_alternate = True
# gpu_alternate = True
elif
args
.
bg_upsampler
.
lower
()
==
'
r
eal
esrgan
'
:
elif
args
.
bg_upsampler
==
'
R
eal
ESRGAN
'
:
from
utils.realesrgan.realesrganer
import
set_realesrgan
from
utils.realesrgan.realesrganer
import
set_realesrgan
# support official RealESRGAN x2 & x4 upsample model
# support official RealESRGAN x2 & x4 upsample model
bg_upscale
=
int
(
args
.
sr_scale
)
if
int
(
args
.
sr_scale
)
in
[
2
,
4
]
else
4
bg_upscale
=
int
(
args
.
sr_scale
)
if
int
(
args
.
sr_scale
)
in
[
2
,
4
]
else
4
...
@@ -137,6 +148,7 @@ def main() -> None:
...
@@ -137,6 +148,7 @@ def main() -> None:
for
file_path
in
list_image_files
(
args
.
input
,
follow_links
=
True
):
for
file_path
in
list_image_files
(
args
.
input
,
follow_links
=
True
):
# read image
# read image
lq
=
Image
.
open
(
file_path
).
convert
(
"RGB"
)
lq
=
Image
.
open
(
file_path
).
convert
(
"RGB"
)
if
args
.
sr_scale
!=
1
:
if
args
.
sr_scale
!=
1
:
lq
=
lq
.
resize
(
lq
=
lq
.
resize
(
tuple
(
math
.
ceil
(
x
*
args
.
sr_scale
)
for
x
in
lq
.
size
),
tuple
(
math
.
ceil
(
x
*
args
.
sr_scale
)
for
x
in
lq
.
size
),
...
@@ -155,12 +167,13 @@ def main() -> None:
...
@@ -155,12 +167,13 @@ def main() -> None:
face_helper
.
get_face_landmarks_5
(
only_center_face
=
args
.
only_center_face
,
resize
=
640
,
eye_dist_threshold
=
5
)
face_helper
.
get_face_landmarks_5
(
only_center_face
=
args
.
only_center_face
,
resize
=
640
,
eye_dist_threshold
=
5
)
face_helper
.
align_warp_face
()
face_helper
.
align_warp_face
()
os
.
makedirs
(
os
.
path
.
join
(
parent_path
,
'cropped_faces'
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
join
(
parent_path
,
'restored_imgs'
),
exist_ok
=
True
)
save_path
=
os
.
path
.
join
(
args
.
output
,
os
.
path
.
relpath
(
file_path
,
args
.
input
))
save_path
=
os
.
path
.
join
(
args
.
output
,
os
.
path
.
relpath
(
file_path
,
args
.
input
))
parent_path
,
img_basename
,
_
=
get_file_name_parts
(
save_path
)
parent_path
,
img_basename
,
_
=
get_file_name_parts
(
save_path
)
os
.
makedirs
(
parent_path
,
exist_ok
=
True
)
os
.
makedirs
(
parent_path
,
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
join
(
parent_path
,
'cropped_faces'
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
join
(
parent_path
,
'restored_faces'
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
join
(
parent_path
,
'restored_faces'
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
join
(
parent_path
,
'restored_imgs'
),
exist_ok
=
True
)
for
i
in
range
(
args
.
repeat_times
):
for
i
in
range
(
args
.
repeat_times
):
basename
=
f
'
{
img_basename
}
_
{
i
}
'
if
i
else
img_basename
basename
=
f
'
{
img_basename
}
_
{
i
}
'
if
i
else
img_basename
restored_img_path
=
os
.
path
.
join
(
parent_path
,
'restored_imgs'
,
f
'
{
basename
}
.
{
img_save_ext
}
'
)
restored_img_path
=
os
.
path
.
join
(
parent_path
,
'restored_imgs'
,
f
'
{
basename
}
.
{
img_save_ext
}
'
)
...
@@ -193,17 +206,20 @@ def main() -> None:
...
@@ -193,17 +206,20 @@ def main() -> None:
if
not
args
.
has_aligned
:
if
not
args
.
has_aligned
:
# upsample the background
# upsample the background
if
bg_upsampler
is
not
None
:
if
bg_upsampler
is
not
None
:
print
(
f
'Upsampling the background image...'
)
print
(
f
'upsampling the background image using
{
args
.
bg_upsampler
}
...'
)
print
(
'bg upsampler'
,
bg_upsampler
.
device
)
if
args
.
bg_upsampler
==
'DiffBIR'
:
if
args
.
bg_upsampler
.
lower
()
==
'diffbir'
:
bg_img
,
_
=
process
(
bg_img
,
_
=
process
(
bg_upsampler
,
[
x
],
steps
=
args
.
steps
,
bg_upsampler
,
[
x
],
steps
=
args
.
steps
,
color_fix_type
=
args
.
color_fix_type
,
color_fix_type
=
args
.
color_fix_type
,
strength
=
1
,
disable_preprocess_model
=
args
.
disable_preprocess_model
,
strength
=
1
,
disable_preprocess_model
=
args
.
disable_preprocess_model
,
cond_fn
=
None
,
tiled
=
False
,
tile_size
=
None
,
tile_stride
=
None
)
cond_fn
=
None
,
tiled
=
False
,
tile_size
=
None
,
tile_stride
=
None
)
bg_img
=
bg_img
[
0
]
bg_img
=
bg_img
[
0
]
else
:
elif
args
.
bg_upsampler
==
'RealESRGAN'
:
bg_img
=
bg_upsampler
.
enhance
(
x
,
outscale
=
args
.
sr_scale
)[
0
]
# resize back to the original size
w
,
h
=
x
.
shape
[:
2
]
input_size
=
(
int
(
w
/
args
.
sr_scale
),
int
(
h
/
args
.
sr_scale
))
x
=
Image
.
fromarray
(
x
).
resize
(
input_size
,
Image
.
LANCZOS
)
bg_img
=
bg_upsampler
.
enhance
(
np
.
array
(
x
),
outscale
=
args
.
sr_scale
)[
0
]
else
:
else
:
bg_img
=
None
bg_img
=
None
face_helper
.
get_inverse_affine
(
None
)
face_helper
.
get_inverse_affine
(
None
)
...
@@ -232,10 +248,7 @@ def main() -> None:
...
@@ -232,10 +248,7 @@ def main() -> None:
# remove padding
# remove padding
restored_img
=
restored_img
[:
lq_resized
.
height
,
:
lq_resized
.
width
,
:]
restored_img
=
restored_img
[:
lq_resized
.
height
,
:
lq_resized
.
width
,
:]
# save restored image
# save restored image
if
args
.
resize_back
and
lq_resized
.
size
!=
lq
.
size
:
Image
.
fromarray
(
restored_img
).
resize
(
lq
.
size
,
Image
.
LANCZOS
).
convert
(
"RGB"
).
save
(
restored_img_path
)
Image
.
fromarray
(
restored_img
).
resize
(
lq
.
size
,
Image
.
LANCZOS
).
convert
(
"RGB"
).
save
(
restored_img_path
)
else
:
Image
.
fromarray
(
restored_img
).
convert
(
"RGB"
).
save
(
restored_img_path
)
print
(
f
"Face image
{
basename
}
saved to
{
parent_path
}
"
)
print
(
f
"Face image
{
basename
}
saved to
{
parent_path
}
"
)
...
...
utils/face_restoration_helper.py
0 → 100644
View file @
e5a6b0a4
This diff is collapsed.
Click to expand it.
utils/realesrgan/realesrganer.py
View file @
e5a6b0a4
...
@@ -307,6 +307,12 @@ def set_realesrgan(bg_tile, device, scale=2):
...
@@ -307,6 +307,12 @@ def set_realesrgan(bg_tile, device, scale=2):
'''
'''
assert
isinstance
(
scale
,
int
),
'Expected param scale to be an integer!'
assert
isinstance
(
scale
,
int
),
'Expected param scale to be an integer!'
use_half
=
False
if
'cuda'
in
str
(
device
):
# set False in CPU/MPS mode
no_half_gpu_list
=
[
'1650'
,
'1660'
]
# set False for GPUs that don't support f16
if
not
True
in
[
gpu
in
torch
.
cuda
.
get_device_name
(
0
)
for
gpu
in
no_half_gpu_list
]:
use_half
=
True
model
=
RRDBNet
(
model
=
RRDBNet
(
num_in_ch
=
3
,
num_in_ch
=
3
,
num_out_ch
=
3
,
num_out_ch
=
3
,
...
@@ -322,6 +328,7 @@ def set_realesrgan(bg_tile, device, scale=2):
...
@@ -322,6 +328,7 @@ def set_realesrgan(bg_tile, device, scale=2):
tile
=
bg_tile
,
tile
=
bg_tile
,
tile_pad
=
40
,
tile_pad
=
40
,
pre_pad
=
0
,
pre_pad
=
0
,
device
=
device
device
=
device
,
half
=
use_half
)
)
return
upsampler
return
upsampler
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment