Commit 575acb69 authored by comfyanonymous

IP2P model loading support.

This is the code to load the model and run inference on it with only a text
prompt. This commit does not contain the nodes to properly use it with an
image input.

This supports both the original SD1 instructpix2pix model and the
diffusers SDXL one.
parent 96b4c757
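
For context: an InstructPix2Pix UNet takes an 8-channel input, the noisy latent concatenated channel-wise with the latent of the conditioning image. A minimal sketch (illustrative tensors only, not ComfyUI API) of why text-prompt-only inference can simply substitute zeros for the image:

```python
import torch

# IP2P UNet input = noisy latent concatenated channel-wise with the
# conditioning image latent, hence the 'in_channels': 8 configs below.
noisy_latent = torch.randn(1, 4, 64, 64)      # x_t
cond_latent = torch.zeros_like(noisy_latent)  # zeros = no image, text prompt only
unet_input = torch.cat([noisy_latent, cond_latent], dim=1)
assert unet_input.shape[1] == 8
```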
@@ -473,6 +473,40 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = comfy.conds.CONDRegular(noise_level)
         return out
+
+class IP2P:
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        if image.shape[1:] != noise.shape[1:]:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
+
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = comfy.conds.CONDRegular(adm)
+        return out
+
+class SD15_instructpix2pix(IP2P, BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.process_ip2p_image_in = lambda image: image
+
+class SDXL_instructpix2pix(IP2P, SDXL):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
+        # self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image)
+        self.process_ip2p_image_in = lambda image: image
 class StableCascade_C(BaseModel):
     def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=StageC)
...
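
The two wrapper classes above use IP2P as a mixin: listing it first makes its extra_conds win Python's method resolution order, while everything else still resolves to BaseModel or SDXL. A self-contained illustration (class names here are made up):

```python
class CondMixin:
    def extra_conds(self, **kwargs):
        return {"c_concat": "from mixin"}

class Base:
    def extra_conds(self, **kwargs):
        return {}

    def apply_model(self):
        return "from base"

class Model(CondMixin, Base):  # mixin listed first, so its extra_conds wins
    pass

m = Model()
assert m.extra_conds() == {"c_concat": "from mixin"}
assert m.apply_model() == "from base"
```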
@@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
     return unet_config
-def model_config_from_unet_config(unet_config):
+def model_config_from_unet_config(unet_config, state_dict=None):
     for model_config in comfy.supported_models.models:
-        if model_config.matches(unet_config):
+        if model_config.matches(unet_config, state_dict):
             return model_config(unet_config)
 
     logging.error("no match {}".format(unet_config))
@@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):
 def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
     unet_config = detect_unet_config(state_dict, unet_key_prefix)
-    model_config = model_config_from_unet_config(unet_config)
+    model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
         return comfy.supported_models_base.BASE(unet_config)
     else:
@@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
             'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
             'use_temporal_attention': False, 'use_temporal_resblock': False}
+
+    SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
+            'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
+            'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
+            'use_temporal_attention': False, 'use_temporal_resblock': False}
     SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
             'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
             'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
@@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
             'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
             'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
 
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]
 
     for unet_config in supported_models:
         matches = True
...
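
The matching loop this hunk feeds into is first-match-wins over supported_models: a template is accepted when all of its checked entries agree with the config detected from the diffusers state dict, which is why SDXL_diffusers_ip2p (essentially the SDXL template with 'in_channels': 8) can simply be appended. Roughly, in simplified form (not the exact ComfyUI code):

```python
def pick_template(detected, templates):
    # First template whose every entry equals the detected value wins.
    for template in templates:
        if all(detected.get(k) == v for k, v in template.items()):
            return template
    return None

templates = [{"context_dim": 2048, "in_channels": 4},
             {"context_dim": 2048, "in_channels": 8}]
assert pick_template({"context_dim": 2048, "in_channels": 8}, templates) == templates[1]
```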
@@ -334,6 +334,11 @@ class Stable_Zero123(supported_models_base.BASE):
         "num_head_channels": -1,
     }
+
+    required_keys = {
+        "cc_projection.weight": None,
+        "cc_projection.bias": None,
+    }
     clip_vision_prefix = "cond_stage_model.model.visual."
 
     latent_format = latent_formats.SD15
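
Why Stable_Zero123 gains required_keys in this commit: like the new SD15_instructpix2pix config below, a Zero123 UNet is an SD1.x model with an 8-channel input, so the UNet config alone can no longer tell the two apart; only Zero123 checkpoints carry the cc_projection weights. A hypothetical check in that spirit:

```python
def looks_like_zero123(state_dict):
    # Zero123 checkpoints include a projection layer for the conditioning
    # embedding that SD1.x instructpix2pix checkpoints lack.
    return all(k in state_dict for k in ("cc_projection.weight", "cc_projection.bias"))

assert looks_like_zero123({"cc_projection.weight": 0, "cc_projection.bias": 0})
assert not looks_like_zero123({"model.diffusion_model.input_blocks.0.0.weight": 0})
```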
@@ -439,6 +444,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
         out = model_base.StableCascade_B(self, device=device)
         return out
+
+class SD15_instructpix2pix(SD15):
+    unet_config = {
+        "context_dim": 768,
+        "model_channels": 320,
+        "use_linear_in_transformer": False,
+        "adm_in_channels": None,
+        "use_temporal_attention": False,
+        "in_channels": 8,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SD15_instructpix2pix(self, device=device)
+
+class SDXL_instructpix2pix(SDXL):
+    unet_config = {
+        "model_channels": 320,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [0, 0, 2, 2, 10, 10],
+        "context_dim": 2048,
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
+        "in_channels": 8,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXL_instructpix2pix(self, device=device)
+
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
 models += [SVD_img2vid]
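
Order in this list matters: model_config_from_unet_config returns the first class whose unet_config matches, and the generic SD15/SDXL configs do not pin in_channels, so the stricter instructpix2pix entries must come before them. Illustrative, with abbreviated template values:

```python
def first_match(named_templates, unet_config):
    for name, template in named_templates:
        if all(unet_config.get(k) == v for k, v in template.items()):
            return name
    return None

named_templates = [
    ("SD15_instructpix2pix", {"context_dim": 768, "in_channels": 8}),
    ("SD15", {"context_dim": 768}),  # would also match an 8-channel config
]
assert first_match(named_templates, {"context_dim": 768, "in_channels": 8}) == "SD15_instructpix2pix"
assert first_match(named_templates, {"context_dim": 768, "in_channels": 4}) == "SD15"
```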
@@ -16,6 +16,8 @@ class BASE:
         "num_head_channels": 64,
     }
+
+    required_keys = {}
     clip_prefix = []
     clip_vision_prefix = None
     noise_aug_config = None
@@ -28,10 +30,14 @@ class BASE:
     manual_cast_dtype = None
 
     @classmethod
-    def matches(s, unet_config):
+    def matches(s, unet_config, state_dict=None):
         for k in s.unet_config:
             if k not in unet_config or s.unet_config[k] != unet_config[k]:
                 return False
+        if state_dict is not None:
+            for k in s.required_keys:
+                if k not in state_dict:
+                    return False
         return True
 
     def model_type(self, state_dict, prefix=""):
...
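
Condensed, the new two-stage check in matches() above looks like this standalone approximation (not the method itself): template equality first, then, when a state dict is available, presence of every required key:

```python
def matches(template, required_keys, unet_config, state_dict=None):
    # Stage 1: every template entry must be present and equal in unet_config.
    for k, v in template.items():
        if k not in unet_config or unet_config[k] != v:
            return False
    # Stage 2: with a state dict, every required key must exist in it.
    if state_dict is not None and any(k not in state_dict for k in required_keys):
        return False
    return True

# A config that fits the template still fails without its required keys:
zero123_keys = ["cc_projection.weight", "cc_projection.bias"]
assert not matches({"in_channels": 8}, zero123_keys, {"in_channels": 8}, state_dict={})
assert matches({"in_channels": 8}, zero123_keys, {"in_channels": 8},
               state_dict={"cc_projection.weight": 0, "cc_projection.bias": 0})
```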