Commit 905857ed authored by comfyanonymous

Take some code from chainner to implement ESRGAN and other upscale models.

parent 8c4ccb55
import logging as logger
from .architecture.face.codeformer import CodeFormer
from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
from .architecture.face.restoreformer_arch import RestoreFormer
from .architecture.HAT import HAT
from .architecture.LaMa import LaMa
from .architecture.MAT import MAT
from .architecture.RRDB import RRDBNet as ESRGAN
from .architecture.SPSR import SPSRNet as SPSR
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
from .architecture.Swin2SR import Swin2SR
from .architecture.SwinIR import SwinIR
from .types import PyTorchModel
class UnsupportedModel(Exception):
    pass
def load_state_dict(state_dict) -> PyTorchModel:
    logger.debug("Loading state dict into pytorch model arch")

    state_dict_keys = list(state_dict.keys())

    # Unwrap checkpoints that store the weights under a "params"-style key.
    if "params_ema" in state_dict_keys:
        state_dict = state_dict["params_ema"]
    elif "params-ema" in state_dict_keys:
        state_dict = state_dict["params-ema"]
    elif "params" in state_dict_keys:
        state_dict = state_dict["params"]

    state_dict_keys = list(state_dict.keys())

    # SRVGGNet Real-ESRGAN (v2)
    if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
        model = RealESRGANv2(state_dict)
    # SPSR (ESRGAN with lots of extra layers)
    elif "f_HR_conv1.0.weight" in state_dict:
        model = SPSR(state_dict)
    # Swift-SRGAN
    elif (
        "model" in state_dict_keys
        and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
    ):
        model = SwiftSRGAN(state_dict)
    # HAT -- be sure it is above SwinIR
    elif "layers.0.residual_group.blocks.0.conv_block.cab.0.weight" in state_dict_keys:
        model = HAT(state_dict)
    # SwinIR / Swin2SR
    elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
        if "patch_embed.proj.weight" in state_dict_keys:
            model = Swin2SR(state_dict)
        else:
            model = SwinIR(state_dict)
    # GFPGAN
    elif (
        "toRGB.0.weight" in state_dict_keys
        and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
    ):
        model = GFPGANv1Clean(state_dict)
    # RestoreFormer
    elif (
        "encoder.conv_in.weight" in state_dict_keys
        and "encoder.down.0.block.0.norm1.weight" in state_dict_keys
    ):
        model = RestoreFormer(state_dict)
    # CodeFormer
    elif (
        "encoder.blocks.0.weight" in state_dict_keys
        and "quantize.embedding.weight" in state_dict_keys
    ):
        model = CodeFormer(state_dict)
    # LaMa
    elif (
        "model.model.1.bn_l.running_mean" in state_dict_keys
        or "generator.model.1.bn_l.running_mean" in state_dict_keys
    ):
        model = LaMa(state_dict)
    # MAT
    elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
        model = MAT(state_dict)
    # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
    else:
        try:
            model = ESRGAN(state_dict)
        except Exception:
            # pylint: disable=raise-missing-from
            raise UnsupportedModel
    return model
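# Minimal usage sketch: load a checkpoint from disk and let load_state_dict
# pick the architecture from the state dict keys alone. _example_load is a
# hypothetical helper (not part of this module) and the default path is a
# placeholder.
def _example_load(path="4x_ESRGAN.pth"):
    import torch

    sd = torch.load(path, map_location="cpu")
    model = load_state_dict(sd).eval()
    print(type(model).__name__)  # e.g. RRDBNet for a plain ESRGAN checkpoint
    return model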
from typing import Union
from .architecture.face.codeformer import CodeFormer
from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
from .architecture.face.restoreformer_arch import RestoreFormer
from .architecture.HAT import HAT
from .architecture.LaMa import LaMa
from .architecture.MAT import MAT
from .architecture.RRDB import RRDBNet as ESRGAN
from .architecture.SPSR import SPSRNet as SPSR
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
from .architecture.Swin2SR import Swin2SR
from .architecture.SwinIR import SwinIR
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT)
PyTorchSRModel = Union[
    RealESRGANv2,
    SPSR,
    SwiftSRGAN,
    ESRGAN,
    SwinIR,
    Swin2SR,
    HAT,
]


def is_pytorch_sr_model(model: object):
    return isinstance(model, PyTorchSRModels)


PyTorchFaceModels = (GFPGANv1Clean, RestoreFormer, CodeFormer)
PyTorchFaceModel = Union[GFPGANv1Clean, RestoreFormer, CodeFormer]


def is_pytorch_face_model(model: object):
    return isinstance(model, PyTorchFaceModels)


PyTorchInpaintModels = (LaMa, MAT)
PyTorchInpaintModel = Union[LaMa, MAT]


def is_pytorch_inpaint_model(model: object):
    return isinstance(model, PyTorchInpaintModels)


PyTorchModels = (*PyTorchSRModels, *PyTorchFaceModels, *PyTorchInpaintModels)
PyTorchModel = Union[PyTorchSRModel, PyTorchFaceModel, PyTorchInpaintModel]


def is_pytorch_model(model: object):
    return isinstance(model, PyTorchModels)
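# Minimal sketch of dispatching on model kind with the guards above;
# describe_model is a hypothetical helper, not part of this module.
def describe_model(model: object) -> str:
    if is_pytorch_sr_model(model):
        return "super-resolution"
    if is_pytorch_face_model(model):
        return "face restoration"
    if is_pytorch_inpaint_model(model):
        return "inpainting"
    return "unsupported"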
import os
from comfy_extras.chainner_models import model_loading
from comfy.sd import load_torch_file
import comfy.model_management
from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions
import torch
class UpscaleModelLoader:
    models_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "models")
    upscale_model_dir = os.path.join(models_dir, "upscale_models")

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model_name": (filter_files_extensions(recursive_search(s.upscale_model_dir), supported_ckpt_extensions), ),
                             }}

    RETURN_TYPES = ("UPSCALE_MODEL",)
    FUNCTION = "load_model"

    CATEGORY = "loaders"

    def load_model(self, model_name):
        model_path = os.path.join(self.upscale_model_dir, model_name)
        sd = load_torch_file(model_path)
        out = model_loading.load_state_dict(sd).eval()
        return (out, )


class ImageUpscaleWithModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"upscale_model": ("UPSCALE_MODEL",),
                             "image": ("IMAGE",),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image"

    def upscale(self, upscale_model, image):
        device = comfy.model_management.get_torch_device()
        upscale_model.to(device)
        in_img = image.movedim(-1, -3).to(device)
        with torch.inference_mode():
            s = upscale_model(in_img).cpu()
        upscale_model.cpu()
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)


NODE_CLASS_MAPPINGS = {
    "UpscaleModelLoader": UpscaleModelLoader,
    "ImageUpscaleWithModel": ImageUpscaleWithModel
}
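# Minimal usage sketch chaining both nodes outside the graph; _example_upscale
# is a hypothetical helper, the checkpoint name is a placeholder, and the image
# follows ComfyUI's IMAGE layout: float32 [batch, height, width, channel] in 0..1.
def _example_upscale(model_name="4x_ESRGAN.pth"):
    image = torch.rand(1, 64, 64, 3)
    (upscale_model,) = UpscaleModelLoader().load_model(model_name)
    (upscaled,) = ImageUpscaleWithModel().upscale(upscale_model, image)
    return upscaled  # e.g. shape (1, 256, 256, 3) for a 4x model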
@@ -981,3 +981,5 @@ def load_custom_nodes():
        load_custom_node(module_path)

load_custom_nodes()

load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))