Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
905857ed
Commit
905857ed
authored
Mar 11, 2023
by
comfyanonymous
Browse files
Take some code from chainner to implement ESRGAN and other upscale models.
parent
8c4ccb55
Changes
45
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
196 additions
and
0 deletions
+196
-0
comfy_extras/chainner_models/model_loading.py
comfy_extras/chainner_models/model_loading.py
+89
-0
comfy_extras/chainner_models/types.py
comfy_extras/chainner_models/types.py
+53
-0
comfy_extras/nodes_upscale_model.py
comfy_extras/nodes_upscale_model.py
+52
-0
models/upscale_models/put_esrgan_and_other_upscale_models_here
...s/upscale_models/put_esrgan_and_other_upscale_models_here
+0
-0
nodes.py
nodes.py
+2
-0
No files found.
comfy_extras/chainner_models/model_loading.py
0 → 100644
View file @
905857ed
import
logging
as
logger
from
.architecture.face.codeformer
import
CodeFormer
from
.architecture.face.gfpganv1_clean_arch
import
GFPGANv1Clean
from
.architecture.face.restoreformer_arch
import
RestoreFormer
from
.architecture.HAT
import
HAT
from
.architecture.LaMa
import
LaMa
from
.architecture.MAT
import
MAT
from
.architecture.RRDB
import
RRDBNet
as
ESRGAN
from
.architecture.SPSR
import
SPSRNet
as
SPSR
from
.architecture.SRVGG
import
SRVGGNetCompact
as
RealESRGANv2
from
.architecture.SwiftSRGAN
import
Generator
as
SwiftSRGAN
from
.architecture.Swin2SR
import
Swin2SR
from
.architecture.SwinIR
import
SwinIR
from
.types
import
PyTorchModel
class
UnsupportedModel
(
Exception
):
pass
def
load_state_dict
(
state_dict
)
->
PyTorchModel
:
logger
.
debug
(
f
"Loading state dict into pytorch model arch"
)
state_dict_keys
=
list
(
state_dict
.
keys
())
if
"params_ema"
in
state_dict_keys
:
state_dict
=
state_dict
[
"params_ema"
]
elif
"params-ema"
in
state_dict_keys
:
state_dict
=
state_dict
[
"params-ema"
]
elif
"params"
in
state_dict_keys
:
state_dict
=
state_dict
[
"params"
]
state_dict_keys
=
list
(
state_dict
.
keys
())
# SRVGGNet Real-ESRGAN (v2)
if
"body.0.weight"
in
state_dict_keys
and
"body.1.weight"
in
state_dict_keys
:
model
=
RealESRGANv2
(
state_dict
)
# SPSR (ESRGAN with lots of extra layers)
elif
"f_HR_conv1.0.weight"
in
state_dict
:
model
=
SPSR
(
state_dict
)
# Swift-SRGAN
elif
(
"model"
in
state_dict_keys
and
"initial.cnn.depthwise.weight"
in
state_dict
[
"model"
].
keys
()
):
model
=
SwiftSRGAN
(
state_dict
)
# HAT -- be sure it is above swinir
elif
"layers.0.residual_group.blocks.0.conv_block.cab.0.weight"
in
state_dict_keys
:
model
=
HAT
(
state_dict
)
# SwinIR
elif
"layers.0.residual_group.blocks.0.norm1.weight"
in
state_dict_keys
:
if
"patch_embed.proj.weight"
in
state_dict_keys
:
model
=
Swin2SR
(
state_dict
)
else
:
model
=
SwinIR
(
state_dict
)
# GFPGAN
elif
(
"toRGB.0.weight"
in
state_dict_keys
and
"stylegan_decoder.style_mlp.1.weight"
in
state_dict_keys
):
model
=
GFPGANv1Clean
(
state_dict
)
# RestoreFormer
elif
(
"encoder.conv_in.weight"
in
state_dict_keys
and
"encoder.down.0.block.0.norm1.weight"
in
state_dict_keys
):
model
=
RestoreFormer
(
state_dict
)
elif
(
"encoder.blocks.0.weight"
in
state_dict_keys
and
"quantize.embedding.weight"
in
state_dict_keys
):
model
=
CodeFormer
(
state_dict
)
# LaMa
elif
(
"model.model.1.bn_l.running_mean"
in
state_dict_keys
or
"generator.model.1.bn_l.running_mean"
in
state_dict_keys
):
model
=
LaMa
(
state_dict
)
# MAT
elif
"synthesis.first_stage.conv_first.conv.resample_filter"
in
state_dict_keys
:
model
=
MAT
(
state_dict
)
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
else
:
try
:
model
=
ESRGAN
(
state_dict
)
except
:
# pylint: disable=raise-missing-from
raise
UnsupportedModel
return
model
comfy_extras/chainner_models/types.py
0 → 100644
View file @
905857ed
from
typing
import
Union
from
.architecture.face.codeformer
import
CodeFormer
from
.architecture.face.gfpganv1_clean_arch
import
GFPGANv1Clean
from
.architecture.face.restoreformer_arch
import
RestoreFormer
from
.architecture.HAT
import
HAT
from
.architecture.LaMa
import
LaMa
from
.architecture.MAT
import
MAT
from
.architecture.RRDB
import
RRDBNet
as
ESRGAN
from
.architecture.SPSR
import
SPSRNet
as
SPSR
from
.architecture.SRVGG
import
SRVGGNetCompact
as
RealESRGANv2
from
.architecture.SwiftSRGAN
import
Generator
as
SwiftSRGAN
from
.architecture.Swin2SR
import
Swin2SR
from
.architecture.SwinIR
import
SwinIR
PyTorchSRModels
=
(
RealESRGANv2
,
SPSR
,
SwiftSRGAN
,
ESRGAN
,
SwinIR
,
Swin2SR
,
HAT
)
PyTorchSRModel
=
Union
[
RealESRGANv2
,
SPSR
,
SwiftSRGAN
,
ESRGAN
,
SwinIR
,
Swin2SR
,
HAT
,
]
def
is_pytorch_sr_model
(
model
:
object
):
return
isinstance
(
model
,
PyTorchSRModels
)
PyTorchFaceModels
=
(
GFPGANv1Clean
,
RestoreFormer
,
CodeFormer
)
PyTorchFaceModel
=
Union
[
GFPGANv1Clean
,
RestoreFormer
,
CodeFormer
]
def
is_pytorch_face_model
(
model
:
object
):
return
isinstance
(
model
,
PyTorchFaceModels
)
PyTorchInpaintModels
=
(
LaMa
,
MAT
)
PyTorchInpaintModel
=
Union
[
LaMa
,
MAT
]
def
is_pytorch_inpaint_model
(
model
:
object
):
return
isinstance
(
model
,
PyTorchInpaintModels
)
PyTorchModels
=
(
*
PyTorchSRModels
,
*
PyTorchFaceModels
,
*
PyTorchInpaintModels
)
PyTorchModel
=
Union
[
PyTorchSRModel
,
PyTorchFaceModel
,
PyTorchInpaintModel
]
def
is_pytorch_model
(
model
:
object
):
return
isinstance
(
model
,
PyTorchModels
)
comfy_extras/nodes_upscale_model.py
0 → 100644
View file @
905857ed
import
os
from
comfy_extras.chainner_models
import
model_loading
from
comfy.sd
import
load_torch_file
import
comfy.model_management
from
nodes
import
filter_files_extensions
,
recursive_search
,
supported_ckpt_extensions
import
torch
class
UpscaleModelLoader
:
models_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))),
"models"
)
upscale_model_dir
=
os
.
path
.
join
(
models_dir
,
"upscale_models"
)
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model_name"
:
(
filter_files_extensions
(
recursive_search
(
s
.
upscale_model_dir
),
supported_ckpt_extensions
),
),
}}
RETURN_TYPES
=
(
"UPSCALE_MODEL"
,)
FUNCTION
=
"load_model"
CATEGORY
=
"loaders"
def
load_model
(
self
,
model_name
):
model_path
=
os
.
path
.
join
(
self
.
upscale_model_dir
,
model_name
)
sd
=
load_torch_file
(
model_path
)
out
=
model_loading
.
load_state_dict
(
sd
).
eval
()
return
(
out
,
)
class
ImageUpscaleWithModel
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"upscale_model"
:
(
"UPSCALE_MODEL"
,),
"image"
:
(
"IMAGE"
,),
}}
RETURN_TYPES
=
(
"IMAGE"
,)
FUNCTION
=
"upscale"
CATEGORY
=
"image"
def
upscale
(
self
,
upscale_model
,
image
):
device
=
comfy
.
model_management
.
get_torch_device
()
upscale_model
.
to
(
device
)
in_img
=
image
.
movedim
(
-
1
,
-
3
).
to
(
device
)
with
torch
.
inference_mode
():
s
=
upscale_model
(
in_img
).
cpu
()
upscale_model
.
cpu
()
s
=
torch
.
clamp
(
s
.
movedim
(
-
3
,
-
1
),
min
=
0
,
max
=
1.0
)
return
(
s
,)
NODE_CLASS_MAPPINGS
=
{
"UpscaleModelLoader"
:
UpscaleModelLoader
,
"ImageUpscaleWithModel"
:
ImageUpscaleWithModel
}
models/upscale_models/put_esrgan_and_other_upscale_models_here
0 → 100644
View file @
905857ed
nodes.py
View file @
905857ed
...
@@ -981,3 +981,5 @@ def load_custom_nodes():
...
@@ -981,3 +981,5 @@ def load_custom_nodes():
load_custom_node
(
module_path
)
load_custom_node
(
module_path
)
load_custom_nodes
()
load_custom_nodes
()
load_custom_node
(
os
.
path
.
join
(
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"comfy_extras"
),
"nodes_upscale_model.py"
))
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment