"ppocr/vscode:/vscode.git/clone" did not exist on "2e9abcb91d058af87b188aee5e7efd1cf34eead0"
Commit f87ec10a authored by comfyanonymous

Support base SDXL and SDXL refiner models.

Large refactor of the model detection and loading code.
parent 9fccf4aa
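For orientation, a minimal sketch (not part of the commit) of how the refactored detection and loading path in this diff fits together; it assumes a ComfyUI checkout on the import path, the checkpoint filename is hypothetical, and error handling is omitted:

# Minimal sketch of the new flow (hypothetical checkpoint path, no error handling).
from comfy import utils, model_management, model_detection

sd = utils.load_torch_file("checkpoints/some_model.safetensors")  # hypothetical file
fp16 = model_management.should_use_fp16()

# Guess the architecture (SD1.x, SD2.x, SDXL base or refiner) from the UNet weights.
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
if model_config is None:
    raise RuntimeError("could not detect model type")

# Build the diffusion model and load its weights in place; CLIP and VAE weights are
# loaded separately, as the updated load_checkpoint_guess_config() does further down.
model = model_config.get_model(sd)
model.load_model_weights(sd, "model.diffusion_model.")

In practice this is wrapped by load_checkpoint_guess_config(), whose updated body appears later in the diff.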
@@ -34,8 +34,10 @@ class ControlNet(nn.Module):
                 channel_mult=(1, 2, 4, 8),
                 conv_resample=True,
                 dims=2,
+                num_classes=None,
                 use_checkpoint=False,
                 use_fp16=False,
+                use_bf16=False,
                 num_heads=-1,
                 num_head_channels=-1,
                 num_heads_upsample=-1,
@@ -51,6 +53,8 @@ class ControlNet(nn.Module):
                 num_attention_blocks=None,
                 disable_middle_self_attn=False,
                 use_linear_in_transformer=False,
+                adm_in_channels=None,
+                transformer_depth_middle=None,
                 ):
        super().__init__()
        if use_spatial_transformer:
@@ -75,6 +79,10 @@ class ControlNet(nn.Module):
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
+        if isinstance(transformer_depth, int):
+            transformer_depth = len(channel_mult) * [transformer_depth]
+        if transformer_depth_middle is None:
+            transformer_depth_middle = transformer_depth[-1]
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
@@ -97,8 +105,10 @@ class ControlNet(nn.Module):
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
+        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = th.bfloat16 if use_bf16 else self.dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
@@ -111,6 +121,24 @@ class ControlNet(nn.Module):
            linear(time_embed_dim, time_embed_dim),
        )

+        if self.num_classes is not None:
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            elif self.num_classes == "sequential":
+                assert adm_in_channels is not None
+                self.label_emb = nn.Sequential(
+                    nn.Sequential(
+                        linear(adm_in_channels, time_embed_dim),
+                        nn.SiLU(),
+                        linear(time_embed_dim, time_embed_dim),
+                    )
+                )
+            else:
+                raise ValueError()
+
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
@@ -179,7 +207,7 @@ class ControlNet(nn.Module):
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint
                        )
@@ -238,7 +266,7 @@ class ControlNet(nn.Module):
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
-                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
@@ -257,7 +285,7 @@ class ControlNet(nn.Module):
    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

-    def forward(self, x, hint, timesteps, context, **kwargs):
+    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
@@ -265,6 +293,14 @@ class ControlNet(nn.Module):
        outs = []

+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+
        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
...
{
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": 1280,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 77,
"model_type": "clip_text_model",
"num_attention_heads": 20,
"num_hidden_layers": 32,
"pad_token_id": 1,
"projection_dim": 512,
"torch_dtype": "float32",
"vocab_size": 49408
}
@@ -29,31 +29,31 @@ class ClipVisionModel():
        outputs = self.model(**inputs)
        return outputs

-def convert_to_transformers(sd):
+def convert_to_transformers(sd, prefix):
    sd_k = sd.keys()
-    if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k:
+    if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
        keys_to_replace = {
-            "embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding",
-            "embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight",
-            "embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight",
-            "embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias",
-            "embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight",
-            "embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias",
-            "embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight",
+            "{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
+            "{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
+            "{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
+            "{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
+            "{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
+            "{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
+            "{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
        }

        for x in keys_to_replace:
            if x in sd_k:
                sd[keys_to_replace[x]] = sd.pop(x)

-        if "embedder.model.visual.proj" in sd_k:
-            sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1)
+        if "{}proj".format(prefix) in sd_k:
+            sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)

-    sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32)
+    sd = transformers_convert(sd, prefix, "vision_model.", 32)
    return sd

-def load_clipvision_from_sd(sd):
-    sd = convert_to_transformers(sd)
+def load_clipvision_from_sd(sd, prefix):
+    sd = convert_to_transformers(sd, prefix)
    if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
    else:
...
@@ -600,7 +600,7 @@ class SpatialTransformer(nn.Module):
                 use_checkpoint=True, dtype=None):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
-            context_dim = [context_dim]
+            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels, dtype=dtype)
@@ -630,7 +630,7 @@ class SpatialTransformer(nn.Module):
    def forward(self, x, context=None, transformer_options={}):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
-            context = [context]
+            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
...
@@ -502,6 +502,7 @@ class UNetModel(nn.Module):
                 disable_middle_self_attn=False,
                 use_linear_in_transformer=False,
                 adm_in_channels=None,
+                 transformer_depth_middle=None,
                 ):
        super().__init__()
        if use_spatial_transformer:
@@ -526,6 +527,10 @@ class UNetModel(nn.Module):
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
+        if isinstance(transformer_depth, int):
+            transformer_depth = len(channel_mult) * [transformer_depth]
+        if transformer_depth_middle is None:
+            transformer_depth_middle = transformer_depth[-1]
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
@@ -631,7 +636,7 @@ class UNetModel(nn.Module):
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint, dtype=self.dtype
                        )
@@ -690,7 +695,7 @@ class UNetModel(nn.Module):
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
-                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint, dtype=self.dtype
            ),
@@ -746,7 +751,7 @@ class UNetModel(nn.Module):
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint, dtype=self.dtype
                        )
...
@@ -2,6 +2,7 @@ import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
+from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np

class BaseModel(torch.nn.Module):
@@ -15,9 +16,9 @@ class BaseModel(torch.nn.Module):
            self.parameterization = "v"
        else:
            self.parameterization = "eps"
-        if "adm_in_channels" in unet_config:
-            self.adm_channels = unet_config["adm_in_channels"]
-        else:
-            self.adm_channels = 0
+        self.adm_channels = unet_config.get("adm_in_channels", None)
+        if self.adm_channels is None:
+            self.adm_channels = 0
        print("v_prediction", v_prediction)
        print("adm", self.adm_channels)
@@ -55,6 +56,25 @@ class BaseModel(torch.nn.Module):
    def is_adm(self):
        return self.adm_channels > 0

+    def encode_adm(self, **kwargs):
+        return None
+
+    def load_model_weights(self, sd, unet_prefix=""):
+        to_load = {}
+        keys = list(sd.keys())
+        for k in keys:
+            if k.startswith(unet_prefix):
+                to_load[k[len(unet_prefix):]] = sd.pop(k)
+
+        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+        if len(m) > 0:
+            print("unet missing:", m)
+        if len(u) > 0:
+            print("unet unexpected:", u)
+        del to_load
+        return self
+
class SD21UNCLIP(BaseModel):
    def __init__(self, unet_config, noise_aug_config, v_prediction=True):
        super().__init__(unet_config, v_prediction)
@@ -95,3 +115,55 @@ class SDInpaint(BaseModel):
    def __init__(self, unet_config, v_prediction=False):
        super().__init__(unet_config, v_prediction)
        self.concat_keys = ("mask", "masked_image")
+
+class SDXLRefiner(BaseModel):
+    def __init__(self, unet_config, v_prediction=False):
+        super().__init__(unet_config, v_prediction)
+        self.embedder = Timestep(256)
+
+    def encode_adm(self, **kwargs):
+        clip_pooled = kwargs["pooled_output"]
+        width = kwargs.get("width", 768)
+        height = kwargs.get("height", 768)
+        crop_w = kwargs.get("crop_w", 0)
+        crop_h = kwargs.get("crop_h", 0)
+
+        if kwargs.get("prompt_type", "") == "negative":
+            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
+        else:
+            aesthetic_score = kwargs.get("aesthetic_score", 6)
+
+        print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
+        out = []
+        out.append(self.embedder(torch.Tensor([width])))
+        out.append(self.embedder(torch.Tensor([height])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([aesthetic_score])))
+        flat = torch.flatten(torch.cat(out))[None, ]
+        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
+
+class SDXL(BaseModel):
+    def __init__(self, unet_config, v_prediction=False):
+        super().__init__(unet_config, v_prediction)
+        self.embedder = Timestep(256)
+
+    def encode_adm(self, **kwargs):
+        clip_pooled = kwargs["pooled_output"]
+        width = kwargs.get("width", 768)
+        height = kwargs.get("height", 768)
+        crop_w = kwargs.get("crop_w", 0)
+        crop_h = kwargs.get("crop_h", 0)
+        target_width = kwargs.get("target_width", width)
+        target_height = kwargs.get("target_height", height)
+
+        print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
+        out = []
+        out.append(self.embedder(torch.Tensor([width])))
+        out.append(self.embedder(torch.Tensor([height])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([target_width])))
+        out.append(self.embedder(torch.Tensor([target_height])))
+        flat = torch.flatten(torch.cat(out))[None, ]
+        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
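As a quick sanity check on the conditioning sizes implied by encode_adm above (a sketch, not part of the commit): each Timestep(256) embedding contributes 256 channels and the pooled CLIP-G output is 1280-wide (hidden_size in the config earlier in this diff), so the concatenated ADM vector is 2560 channels for the refiner (five embedded scalars) and 2816 for the base model (six):

# Sketch: ADM widths implied by SDXLRefiner.encode_adm and SDXL.encode_adm (not part of the commit).
pooled_g = 1280            # pooled CLIP-G output (hidden_size 1280 in clip_config_bigg.json)
per_scalar = 256           # one Timestep(256) embedding per conditioning value
refiner_adm = pooled_g + 5 * per_scalar   # width, height, crop_w, crop_h, aesthetic_score
base_adm = pooled_g + 6 * per_scalar      # width, height, crop_w, crop_h, target_width, target_height
print(refiner_adm, base_adm)              # 2560 2816

The 2560 value matches the adm_in_channels declared for the SDXLRefiner entry at the end of this diff.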
from . import supported_models

def count_blocks(state_dict_keys, prefix_string):
    count = 0
    while True:
        c = False
        for k in state_dict_keys:
            if k.startswith(prefix_string.format(count)):
                c = True
                break
        if c == False:
            break
        count += 1
    return count

def detect_unet_config(state_dict, key_prefix, use_fp16):
    state_dict_keys = list(state_dict.keys())

    num_res_blocks = 2
    unet_config = {
        "use_checkpoint": False,
        "image_size": 32,
        "out_channels": 4,
        "num_res_blocks": num_res_blocks,
        "use_spatial_transformer": True,
        "legacy": False
    }

    y_input = '{}label_emb.0.0.weight'.format(key_prefix)
    if y_input in state_dict_keys:
        unet_config["num_classes"] = "sequential"
        unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
    else:
        unet_config["adm_in_channels"] = None

    unet_config["use_fp16"] = use_fp16
    model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
    in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]

    num_res_blocks = []
    channel_mult = []
    attention_resolutions = []
    transformer_depth = []
    context_dim = None
    use_linear_in_transformer = False

    current_res = 1
    count = 0

    last_res_blocks = 0
    last_transformer_depth = 0
    last_channel_mult = 0

    while True:
        prefix = '{}input_blocks.{}.'.format(key_prefix, count)
        block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
        if len(block_keys) == 0:
            break

        if "{}0.op.weight".format(prefix) in block_keys: #new layer
            if last_transformer_depth > 0:
                attention_resolutions.append(current_res)
                transformer_depth.append(last_transformer_depth)
            num_res_blocks.append(last_res_blocks)
            channel_mult.append(last_channel_mult)

            current_res *= 2
            last_res_blocks = 0
            last_transformer_depth = 0
            last_channel_mult = 0
        else:
            res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
            if res_block_prefix in block_keys:
                last_res_blocks += 1
                last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels

            transformer_prefix = prefix + "1.transformer_blocks."
            transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
            if len(transformer_keys) > 0:
                last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
                if context_dim is None:
                    context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
                    use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2

        count += 1

    if last_transformer_depth > 0:
        attention_resolutions.append(current_res)
        transformer_depth.append(last_transformer_depth)
    num_res_blocks.append(last_res_blocks)
    channel_mult.append(last_channel_mult)
    transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')

    if len(set(num_res_blocks)) == 1:
        num_res_blocks = num_res_blocks[0]

    if len(set(transformer_depth)) == 1:
        transformer_depth = transformer_depth[0]

    unet_config["in_channels"] = in_channels
    unet_config["model_channels"] = model_channels
    unet_config["num_res_blocks"] = num_res_blocks
    unet_config["attention_resolutions"] = attention_resolutions
    unet_config["transformer_depth"] = transformer_depth
    unet_config["channel_mult"] = channel_mult
    unet_config["transformer_depth_middle"] = transformer_depth_middle
    unet_config['use_linear_in_transformer'] = use_linear_in_transformer
    unet_config["context_dim"] = context_dim
    return unet_config

def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)

    for model_config in supported_models.models:
        if model_config.matches(unet_config):
            return model_config(unet_config)

    return None
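To illustrate how the block-counting helper above behaves, a small standalone sketch (the key names are invented to mimic a UNet state dict, and it assumes count_blocks as defined above is in scope, e.g. pasted into the same module):

# Illustration only: count_blocks scans indices 0, 1, 2, ... until no key matches the
# formatted prefix, which is how detect_unet_config measures transformer depth per block.
fake_keys = [
    "input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight",
    "input_blocks.1.1.transformer_blocks.1.attn2.to_k.weight",
]
depth = count_blocks(fake_keys, "input_blocks.1.1.transformer_blocks." + "{}")
print(depth)  # 2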
@@ -229,7 +229,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                timestep_ = torch.cat([timestep] * batch_chunks)

                if control is not None:
-                    c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
+                    c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))

                transformer_options = {}
                if 'transformer_options' in model_options:
@@ -460,8 +460,7 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
            n[name] = uncond_fill_func(cond_cnets, x)
            uncond[temp[1]] = [o[0], n]

-def encode_adm(model, conds, batch_size, device):
+def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
    for t in range(len(conds)):
        x = conds[t]
        adm_out = None
@@ -469,7 +468,11 @@ def encode_adm(model, conds, batch_size, device):
            adm_out = x[1]["adm"]
        else:
            params = x[1].copy()
+            params["width"] = params.get("width", width * 8)
+            params["height"] = params.get("height", height * 8)
+            params["prompt_type"] = params.get("prompt_type", prompt_type)
            adm_out = model.encode_adm(device=device, **params)
        if adm_out is not None:
            x[1] = x[1].copy()
            x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device)
@@ -580,8 +583,8 @@ class KSampler:
            precision_scope = contextlib.nullcontext

        if self.model.is_adm():
-            positive = encode_adm(self.model, positive, noise.shape[0], self.device)
-            negative = encode_adm(self.model, negative, noise.shape[0], self.device)
+            positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
+            negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")

        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
...
@@ -3,8 +3,6 @@ import contextlib
import copy
import inspect
-from . import sd1_clip
-from . import sd2_clip
from comfy import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL
@@ -17,19 +15,28 @@ from . import clip_vision
from . import gligen
from . import diffusers_convert
from . import model_base
+from . import model_detection

-def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
-    replace_prefix = {"model.diffusion_model.": "diffusion_model."}
-    for rp in replace_prefix:
-        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys())))
-        for x in replace:
-            sd[x[1]] = sd.pop(x[0])
+from . import sd1_clip
+from . import sd2_clip
+
+def load_model_weights(model, sd):
    m, u = model.load_state_dict(sd, strict=False)
+    m = set(m)
+    unexpected_keys = set(u)

    k = list(sd.keys())
    for x in k:
-        # print(x)
+        if x not in unexpected_keys:
+            w = sd.pop(x)
+            del w
+    if len(m) > 0:
+        print("missing", m)
+    return model
+
+def load_clip_weights(model, sd):
+    k = list(sd.keys())
+    for x in k:
        if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
            y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
            sd[y] = sd.pop(x)
@@ -39,20 +46,8 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
        if ids.dtype == torch.float32:
            sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

-    sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
+    sd = utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
+    return load_model_weights(model, sd)

-    for x in load_state_dict_to:
-        x.load_state_dict(sd, strict=False)
-
-    if len(m) > 0 and verbose:
-        print("missing keys:")
-        print(m)
-    if len(u) > 0 and verbose:
-        print("unexpected keys:")
-        print(u)
-
-    model.eval()
-    return model

LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
@@ -66,18 +61,26 @@ LORA_CLIP_MAP = {
LORA_UNET_MAP_ATTENTIONS = {
    "proj_in": "proj_in",
    "proj_out": "proj_out",
-    "transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
-    "transformer_blocks.0.attn1.to_k": "transformer_blocks_0_attn1_to_k",
-    "transformer_blocks.0.attn1.to_v": "transformer_blocks_0_attn1_to_v",
-    "transformer_blocks.0.attn1.to_out.0": "transformer_blocks_0_attn1_to_out_0",
-    "transformer_blocks.0.attn2.to_q": "transformer_blocks_0_attn2_to_q",
-    "transformer_blocks.0.attn2.to_k": "transformer_blocks_0_attn2_to_k",
-    "transformer_blocks.0.attn2.to_v": "transformer_blocks_0_attn2_to_v",
-    "transformer_blocks.0.attn2.to_out.0": "transformer_blocks_0_attn2_to_out_0",
-    "transformer_blocks.0.ff.net.0.proj": "transformer_blocks_0_ff_net_0_proj",
-    "transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
}

+transformer_lora_blocks = {
+    "transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q",
+    "transformer_blocks.{}.attn1.to_k": "transformer_blocks_{}_attn1_to_k",
+    "transformer_blocks.{}.attn1.to_v": "transformer_blocks_{}_attn1_to_v",
+    "transformer_blocks.{}.attn1.to_out.0": "transformer_blocks_{}_attn1_to_out_0",
+    "transformer_blocks.{}.attn2.to_q": "transformer_blocks_{}_attn2_to_q",
+    "transformer_blocks.{}.attn2.to_k": "transformer_blocks_{}_attn2_to_k",
+    "transformer_blocks.{}.attn2.to_v": "transformer_blocks_{}_attn2_to_v",
+    "transformer_blocks.{}.attn2.to_out.0": "transformer_blocks_{}_attn2_to_out_0",
+    "transformer_blocks.{}.ff.net.0.proj": "transformer_blocks_{}_ff_net_0_proj",
+    "transformer_blocks.{}.ff.net.2": "transformer_blocks_{}_ff_net_2",
+}
+
+for i in range(10):
+    for k in transformer_lora_blocks:
+        LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i)
+
LORA_UNET_MAP_RESNET = {
    "in_layers.2": "resnets_{}_conv1",
    "emb_layers.1": "resnets_{}_time_emb_proj",
@@ -470,21 +473,12 @@ def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
class CLIP:
-    def __init__(self, config={}, embedding_directory=None, no_init=False):
+    def __init__(self, target=None, embedding_directory=None, no_init=False):
        if no_init:
            return
-        self.target_clip = config["target"]
-        if "params" in config:
-            params = config["params"]
-        else:
-            params = {}
-
-        if self.target_clip.endswith("FrozenOpenCLIPEmbedder"):
-            clip = sd2_clip.SD2ClipModel
-            tokenizer = sd2_clip.SD2Tokenizer
-        elif self.target_clip.endswith("FrozenCLIPEmbedder"):
-            clip = sd1_clip.SD1ClipModel
-            tokenizer = sd1_clip.SD1Tokenizer
+        params = target.params
+        clip = target.clip
+        tokenizer = target.tokenizer

        self.device = model_management.text_encoder_device()
        params["device"] = self.device
@@ -497,11 +491,11 @@ class CLIP:
    def clone(self):
        n = CLIP(no_init=True)
-        n.target_clip = self.target_clip
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
+        n.device = self.device
        return n

    def load_from_state_dict(self, sd):
@@ -521,21 +515,22 @@ class CLIP:
        self.cond_stage_model.clip_layer(self.layer_idx)
        try:
            self.patcher.patch_model()
-            cond = self.cond_stage_model.encode_token_weights(tokens)
+            cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
            self.patcher.unpatch_model()
        except Exception as e:
            self.patcher.unpatch_model()
            raise e

+        cond_out = cond
        if return_pooled:
-            eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
-            pooled = cond[:, eos_token_index]
-            return cond, pooled
-        return cond
+            return cond_out, pooled
+        return cond_out

    def encode(self, text):
        tokens = self.tokenize(text)
        return self.encode_from_tokens(tokens)

class VAE:
    def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
        if config is None:
@@ -668,10 +663,10 @@ class ControlNet:
        self.previous_controlnet = None
        self.global_average_pooling = global_average_pooling

-    def get_control(self, x_noisy, t, cond_txt, batched_number):
+    def get_control(self, x_noisy, t, cond, batched_number):
        control_prev = None
        if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt, batched_number)
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        output_dtype = x_noisy.dtype
        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
@@ -689,7 +684,9 @@ class ControlNet:
        with precision_scope(model_management.get_autocast_device(self.device)):
            self.control_model = model_management.load_if_low_vram(self.control_model)
-            control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
+            context = torch.cat(cond['c_crossattn'], 1)
+            y = cond.get('c_adm', None)
+            control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
            self.control_model = model_management.unload_if_low_vram(self.control_model)
        out = {'middle':[], 'output': []}
        autocast_enabled = torch.is_autocast_enabled()
@@ -749,60 +746,28 @@ class ControlNet:
def load_controlnet(ckpt_path, model=None):
    controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
-    pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
+    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
-    sd2 = False
-    key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
+    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
+        prefix = "control_model."
    elif key in controlnet_data:
-        pass
+        prefix = ""
    else:
        net = load_t2i_adapter(controlnet_data)
        if net is None:
            print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
        return net

-    context_dim = controlnet_data[key].shape[1]
-    use_fp16 = False
-    if model_management.should_use_fp16() and controlnet_data[key].dtype == torch.float16:
-        use_fp16 = True
-
-    if context_dim == 768:
-        #SD1.x
-        control_model = cldm.ControlNet(image_size=32,
-                                        in_channels=4,
-                                        hint_channels=3,
-                                        model_channels=320,
-                                        attention_resolutions=[ 4, 2, 1 ],
-                                        num_res_blocks=2,
-                                        channel_mult=[ 1, 2, 4, 4 ],
-                                        num_heads=8,
-                                        use_spatial_transformer=True,
-                                        transformer_depth=1,
-                                        context_dim=context_dim,
-                                        use_checkpoint=False,
-                                        legacy=False,
-                                        use_fp16=use_fp16)
-    else:
-        #SD2.x
-        control_model = cldm.ControlNet(image_size=32,
-                                        in_channels=4,
-                                        hint_channels=3,
-                                        model_channels=320,
-                                        attention_resolutions=[ 4, 2, 1 ],
-                                        num_res_blocks=2,
-                                        channel_mult=[ 1, 2, 4, 4 ],
-                                        num_head_channels=64,
-                                        use_spatial_transformer=True,
-                                        use_linear_in_transformer=True,
-                                        transformer_depth=1,
-                                        context_dim=context_dim,
-                                        use_checkpoint=False,
-                                        legacy=False,
-                                        use_fp16=use_fp16)
+    use_fp16 = model_management.should_use_fp16()
+    controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config
+    controlnet_config.pop("out_channels")
+    controlnet_config["hint_channels"] = 3
+    control_model = cldm.ControlNet(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
@@ -823,9 +788,10 @@ def load_controlnet(ckpt_path, model=None):
                pass
            w = WeightsLoader()
            w.control_model = control_model
-            w.load_state_dict(controlnet_data, strict=False)
+            missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
        else:
-            control_model.load_state_dict(controlnet_data, strict=False)
+            missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
+        print(missing, unexpected)

    if use_fp16:
        control_model = control_model.half()
@@ -850,10 +816,10 @@ class T2IAdapter:
        self.cond_hint_original = None
        self.cond_hint = None

-    def get_control(self, x_noisy, t, cond_txt, batched_number):
+    def get_control(self, x_noisy, t, cond, batched_number):
        control_prev = None
        if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt, batched_number)
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
@@ -929,12 +895,21 @@ class T2IAdapter:
def load_t2i_adapter(t2i_data):
    keys = t2i_data.keys()
+    if 'adapter' in keys:
+        t2i_data = t2i_data['adapter']
+        keys = t2i_data.keys()
    if "body.0.in_conv.weight" in keys:
        cin = t2i_data['body.0.in_conv.weight'].shape[1]
        model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
    elif 'conv_in.weight' in keys:
        cin = t2i_data['conv_in.weight'].shape[1]
-        model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
+        channel = t2i_data['conv_in.weight'].shape[0]
+        ksize = t2i_data['body.0.block2.weight'].shape[2]
+        use_conv = False
+        down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
+        if len(down_opts) > 0:
+            use_conv = True
+        model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv)
    else:
        return None
    model_ad.load_state_dict(t2i_data)
@@ -1010,17 +985,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
    class WeightsLoader(torch.nn.Module):
        pass

-    w = WeightsLoader()
-    load_state_dict_to = []
-    if output_vae:
-        vae = VAE(scale_factor=scale_factor, config=vae_config)
-        w.first_stage_model = vae.first_stage_model
-        load_state_dict_to = [w]
-
-    if output_clip:
-        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
-        w.cond_stage_model = clip.cond_stage_model
-        load_state_dict_to = [w]
+    if state_dict is None:
+        state_dict = utils.load_torch_file(ckpt_path)

    if config['model']["target"].endswith("LatentInpaintDiffusion"):
        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
@@ -1029,13 +995,33 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
    else:
        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)

-    if state_dict is None:
-        state_dict = utils.load_torch_file(ckpt_path)
-
-    model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to)
-
    if fp16:
        model = model.half()

+    model.load_model_weights(state_dict, "model.diffusion_model.")
+
+    if output_vae:
+        w = WeightsLoader()
+        vae = VAE(scale_factor=scale_factor, config=vae_config)
+        w.first_stage_model = vae.first_stage_model
+        load_model_weights(w, state_dict)
+
+    if output_clip:
+        w = WeightsLoader()
+        class EmptyClass:
+            pass
+        clip_target = EmptyClass()
+        clip_target.params = clip_config["params"]
+        if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
+            clip_target.clip = sd2_clip.SD2ClipModel
+            clip_target.tokenizer = sd2_clip.SD2Tokenizer
+        elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
+            clip_target.clip = sd1_clip.SD1ClipModel
+            clip_target.tokenizer = sd1_clip.SD1Tokenizer
+        clip = CLIP(clip_target, embedding_directory=embedding_directory)
+        w.cond_stage_model = clip.cond_stage_model
+        load_clip_weights(w, state_dict)
+
    return (ModelPatcher(model), clip, vae)
@@ -1045,139 +1031,41 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    clip = None
    clipvision = None
    vae = None
+    model = None
+    clip_target = None

    fp16 = model_management.should_use_fp16()

    class WeightsLoader(torch.nn.Module):
        pass

-    w = WeightsLoader()
-    load_state_dict_to = []
-    if output_vae:
-        vae = VAE()
-        w.first_stage_model = vae.first_stage_model
-        load_state_dict_to = [w]
-
-    if output_clip:
-        clip_config = {}
-        if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
-            clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
-        else:
-            clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
-        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
-        w.cond_stage_model = clip.cond_stage_model
-        load_state_dict_to = [w]
-
-    clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight"
-    noise_aug_config = None
-    if clipvision_key in sd_keys:
-        size = sd[clipvision_key].shape[1]
+    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
+    if model_config is None:
+        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+
+    if model_config.clip_vision_prefix is not None:
        if output_clipvision:
-            clipvision = clip_vision.load_clipvision_from_sd(sd)
+            clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix)

-        noise_aug_key = "noise_augmentor.betas"
-        if noise_aug_key in sd_keys:
-            noise_aug_config = {}
-            params = {}
-            noise_schedule_config = {}
-            noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
-            noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
-            params["noise_schedule_config"] = noise_schedule_config
-            noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
-            if size == 1280: #h
-                params["timestep_dim"] = 1024
-            elif size == 1024: #l
-                params["timestep_dim"] = 768
-            noise_aug_config['params'] = params
-
-    sd_config = {
-        "linear_start": 0.00085,
-        "linear_end": 0.012,
-        "num_timesteps_cond": 1,
-        "log_every_t": 200,
-        "timesteps": 1000,
-        "first_stage_key": "jpg",
-        "cond_stage_key": "txt",
-        "image_size": 64,
-        "channels": 4,
-        "cond_stage_trainable": False,
-        "monitor": "val/loss_simple_ema",
-        "scale_factor": 0.18215,
-        "use_ema": False,
-    }
-
-    unet_config = {
-        "use_checkpoint": False,
-        "image_size": 32,
-        "out_channels": 4,
-        "attention_resolutions": [
-            4,
-            2,
-            1
-        ],
-        "num_res_blocks": 2,
-        "channel_mult": [
-            1,
-            2,
-            4,
-            4
-        ],
-        "use_spatial_transformer": True,
-        "transformer_depth": 1,
-        "legacy": False
-    }
-
-    if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2:
-        unet_config['use_linear_in_transformer'] = True
-
-    unet_config["use_fp16"] = fp16
-    unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
-    unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
-    unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
-
-    sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
-
-    unclip_model = False
-    inpaint_model = False
-    if noise_aug_config is not None: #SD2.x unclip model
-        sd_config["noise_aug_config"] = noise_aug_config
-        sd_config["image_size"] = 96
-        sd_config["embedding_dropout"] = 0.25
-        sd_config["conditioning_key"] = 'crossattn-adm'
-        unclip_model = True
-    elif unet_config["in_channels"] > 4: #inpainting model
-        sd_config["conditioning_key"] = "hybrid"
-        sd_config["finetune_keys"] = None
-        inpaint_model = True
-    else:
-        sd_config["conditioning_key"] = "crossattn"
-
-    if unet_config["context_dim"] == 768:
-        unet_config["num_heads"] = 8 #SD1.x
-    else:
-        unet_config["num_head_channels"] = 64 #SD2.x
-
-    unclip = 'model.diffusion_model.label_emb.0.0.weight'
-    if unclip in sd_keys:
-        unet_config["num_classes"] = "sequential"
-        unet_config["adm_in_channels"] = sd[unclip].shape[1]
-
-    v_prediction = False
-    if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
-        k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
-        out = sd[k]
-        if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
-            v_prediction = True
-            sd_config["parameterization"] = 'v'
-
-    if inpaint_model:
-        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
-    elif unclip_model:
-        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
-    else:
-        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
-
-    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
+    model = model_config.get_model(sd)
+    model.load_model_weights(sd, "model.diffusion_model.")
+
+    if output_vae:
+        vae = VAE(scale_factor=model_config.vae_scale_factor)
+        w = WeightsLoader()
+        w.first_stage_model = vae.first_stage_model
+        load_model_weights(w, sd)
+
+    if output_clip:
+        w = WeightsLoader()
+        clip_target = model_config.clip_target()
+        clip = CLIP(clip_target, embedding_directory=embedding_directory)
+        w.cond_stage_model = clip.cond_stage_model
+        sd = model_config.process_clip_state_dict(sd)
+        load_model_weights(w, sd)
+
+    left_over = sd.keys()
+    if len(left_over) > 0:
+        print("left over keys:", left_over)

    return (ModelPatcher(model), clip, vae, clipvision)
@@ -8,11 +8,14 @@ import zipfile
class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
-        z_empty = self.encode(self.empty_tokens)
+        z_empty, _ = self.encode(self.empty_tokens)
        output = []
+        first_pooled = None
        for x in token_weight_pairs:
            tokens = [list(map(lambda a: a[0], x))]
-            z = self.encode(tokens)
+            z, pooled = self.encode(tokens)
+            if first_pooled is None:
+                first_pooled = pooled
            for i in range(len(z)):
                for j in range(len(z[i])):
                    weight = x[j][1]
@@ -20,7 +23,7 @@ class ClipTokenWeightEncoder:
            output += [z]
        if (len(output) == 0):
            return self.encode(self.empty_tokens)
-        return torch.cat(output, dim=-2).cpu()
+        return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()

class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
@@ -50,6 +53,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
        self.layer = layer
        self.layer_idx = None
        self.empty_tokens = [[49406] + [49407] * 76]
+        self.text_projection = None
+        self.layer_norm_hidden_state = True
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) <= 12
@@ -112,9 +117,13 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
-            z = self.transformer.text_model.final_layer_norm(z)
+            if self.layer_norm_hidden_state:
+                z = self.transformer.text_model.final_layer_norm(z)

-        return z
+        pooled_output = outputs.pooler_output
+        if self.text_projection is not None:
+            pooled_output = pooled_output @ self.text_projection
+        return z, pooled_output

    def encode(self, tokens):
        return self(tokens)
@@ -204,7 +213,7 @@ def expand_directory_list(directories):
            dirs.add(root)
    return list(dirs)

-def load_embed(embedding_name, embedding_directory):
+def load_embed(embedding_name, embedding_directory, embedding_size):
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]
@@ -253,13 +262,23 @@ def load_embed(embedding_name, embedding_directory):
    if embed_out is None:
        if 'string_to_param' in embed:
            values = embed['string_to_param'].values()
+            embed_out = next(iter(values))
+        elif isinstance(embed, list):
+            out_list = []
+            for x in range(len(embed)):
+                for k in embed[x]:
+                    t = embed[x][k]
+                    if t.shape[-1] != embedding_size:
+                        continue
+                    out_list.append(t.reshape(-1, t.shape[-1]))
+            embed_out = torch.cat(out_list, dim=0)
        else:
            values = embed.values()
            embed_out = next(iter(values))
    return embed_out

class SD1Tokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768):
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
@@ -275,17 +294,18 @@ class SD1Tokenizer:
        self.embedding_directory = embedding_directory
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
+        self.embedding_size = embedding_size

    def _try_get_embedding(self, embedding_name:str):
        '''
        Takes a potential embedding name and tries to retrieve it.
        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
        '''
-        embed = load_embed(embedding_name, self.embedding_directory)
+        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size)
        if embed is None:
            stripped = embedding_name.strip(',')
            if len(stripped) < len(embedding_name):
-                embed = load_embed(stripped, self.embedding_directory)
+                embed = load_embed(stripped, self.embedding_directory, self.embedding_size)
            return (embed, embedding_name[len(stripped):])
        return (embed, "")
...
@@ -31,4 +31,4 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, tokenizer_path=None, embedding_directory=None):
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
from comfy import sd1_clip
import torch
import os
class SDXLClipG(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
self.layer_norm_hidden_state = False
if layer == "last":
pass
elif layer == "penultimate":
layer_idx = -1
self.clip_layer(layer_idx)
elif self.layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < 32
self.clip_layer(layer_idx)
else:
raise NotImplementedError()
def clip_layer(self, layer_idx):
if layer_idx < 0:
layer_idx -= 1 #The real last layer of SD2.x clip is the penultimate one. The last one might contain garbage.
if abs(layer_idx) >= 32:
self.layer = "hidden"
self.layer_idx = -2
else:
self.layer = "hidden"
self.layer_idx = layer_idx
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
class SDXLTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
class SDXLClipModel(torch.nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device)
        self.clip_l.layer_norm_hidden_state = False
        self.clip_g = SDXLClipG(device=device)

    def clip_layer(self, layer_idx):
        self.clip_l.clip_layer(layer_idx)
        self.clip_g.clip_layer(layer_idx)

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_g = token_weight_pairs["g"]
        token_weight_pairs_l = token_weight_pairs["l"]
        g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
        return torch.cat([l_out, g_out], dim=-1), g_pooled
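# Shape note (sketch): the CLIP-L hidden states are 768-wide and the CLIP-bigG
# hidden states are 1280-wide, so the torch.cat above produces a 2048-wide context
# tensor, matching the "context_dim": 2048 entry in the SDXL unet_config further
# down. Only the CLIP-bigG pooled output is returned and ends up as the
# "pooled_output" carried in the conditioning.
#
#   cond, pooled = SDXLClipModel().encode_token_weights(tokens)
#   # cond.shape[-1] == 2048, pooled.shape[-1] == 1280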
class SDXLRefinerClipModel(torch.nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.clip_g = SDXLClipG(device=device)

    def clip_layer(self, layer_idx):
        self.clip_g.clip_layer(layer_idx)

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_g = token_weight_pairs["g"]
        g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
        return g_out, g_pooled
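# End-to-end sketch (illustrative only; these classes are normally constructed and
# loaded with checkpoint weights through the higher-level CLIP wrapper, not used
# directly on randomly initialized parameters):
#
#   tokenizer = SDXLTokenizer()
#   model = SDXLClipModel(device="cpu")
#   cond, pooled = model.encode_token_weights(tokenizer.tokenize_with_weights("a photo of a cat"))
#
# The refiner pipeline only carries CLIP-bigG, so SDXLRefinerClipModel returns the
# 1280-wide hidden states and pooled output without any concatenation.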
import torch
from . import model_base
from . import utils

from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip

from . import supported_models_base

class SD15(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    vae_scale_factor = 0.18215

    def process_clip_state_dict(self, state_dict):
        k = list(state_dict.keys())
        for x in k:
            if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
                y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
                state_dict[y] = state_dict.pop(x)

        if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
            ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
            if ids.dtype == torch.float32:
                state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
class SD20(supported_models_base.BASE):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
    }

    vae_scale_factor = 0.18215

    def v_prediction(self, state_dict):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
            out = state_dict[k]
            if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return True
        return False

    def process_clip_state_dict(self, state_dict):
        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
class SD21UnclipL(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}

class SD21UnclipH(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}
class SDXLRefiner(supported_models_base.BASE):
    unet_config = {
        "model_channels": 384,
        "use_linear_in_transformer": True,
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 4, 4, 0],
    }

    vae_scale_factor = 0.13025

    def get_model(self, state_dict):
        return model_base.SDXLRefiner(self.unet_config)

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"

        state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
class SDXL(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816
    }

    vae_scale_factor = 0.13025

    def get_model(self, state_dict):
        return model_base.SDXL(self.unet_config)

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"

        state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]
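# Detection sketch (hypothetical helper, not part of this commit): checkpoint
# loading can pick a config by comparing the UNet hyperparameters recovered from
# the state dict against each entry in `models`, in order, via BASE.matches():
#
#   def find_model_config(unet_config):
#       for candidate in models:
#           if candidate.matches(unet_config):
#               return candidate(unet_config)
#       return None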
import torch
from . import model_base
from . import utils

def state_dict_key_replace(state_dict, keys_to_replace):
    for x in keys_to_replace:
        if x in state_dict:
            state_dict[keys_to_replace[x]] = state_dict.pop(x)
    return state_dict

def state_dict_prefix_replace(state_dict, replace_prefix):
    for rp in replace_prefix:
        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
        for x in replace:
            state_dict[x[1]] = state_dict.pop(x[0])
    return state_dict
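# Worked example (sketch) for the two helpers above, using made-up keys:
#
#   sd = {"conditioner.embedders.0.transformer.text_model.x": 1}
#   state_dict_prefix_replace(sd, {"conditioner.embedders.0.transformer.": "cond_stage_model.clip_l.transformer."})
#   # sd is now {"cond_stage_model.clip_l.transformer.text_model.x": 1}
#
#   state_dict_key_replace(sd, {"old.key": "new.key"})  # renames exact keys only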
class ClipTarget:
    def __init__(self, tokenizer, clip):
        self.clip = clip
        self.tokenizer = tokenizer
        self.params = {}
class BASE:
    unet_config = {}
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    clip_prefix = []
    clip_vision_prefix = None
    noise_aug_config = None

    @classmethod
    def matches(s, unet_config):
        for k in s.unet_config:
            if s.unet_config[k] != unet_config[k]:
                return False
        return True

    def v_prediction(self, state_dict):
        return False

    def inpaint_model(self):
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config):
        self.unet_config = unet_config
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]

    def get_model(self, state_dict):
        if self.inpaint_model():
            return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict))
        elif self.noise_aug_config is not None:
            return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
        else:
            return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict))

    def process_clip_state_dict(self, state_dict):
        return state_dict
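# Usage sketch (illustrative): matches() only checks the keys a subclass declares
# in unet_config, so it is a subset match against the detected UNet hyperparameters;
# __init__ then merges the unet_extra_config defaults into that config.
#
#   detected = {"model_channels": 320, "use_linear_in_transformer": True,
#               "transformer_depth": [0, 2, 10], "context_dim": 2048,
#               "adm_in_channels": 2816, "in_channels": 4}
#   # supported_models.SDXL.matches(detected) -> True
#   # supported_models.SDXL(detected).unet_config also gains the num_heads /
#   # num_head_channels defaults from unet_extra_config.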
...@@ -26,10 +26,10 @@ def load_torch_file(ckpt, safe_load=False): ...@@ -26,10 +26,10 @@ def load_torch_file(ckpt, safe_load=False):
def transformers_convert(sd, prefix_from, prefix_to, number): def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = { keys_to_replace = {
"{}.positional_embedding": "{}.embeddings.position_embedding.weight", "{}positional_embedding": "{}embeddings.position_embedding.weight",
"{}.token_embedding.weight": "{}.embeddings.token_embedding.weight", "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
"{}.ln_final.weight": "{}.final_layer_norm.weight", "{}ln_final.weight": "{}final_layer_norm.weight",
"{}.ln_final.bias": "{}.final_layer_norm.bias", "{}ln_final.bias": "{}final_layer_norm.bias",
} }
for k in keys_to_replace: for k in keys_to_replace:
...@@ -48,19 +48,19 @@ def transformers_convert(sd, prefix_from, prefix_to, number): ...@@ -48,19 +48,19 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
for resblock in range(number): for resblock in range(number):
for x in resblock_to_replace: for x in resblock_to_replace:
for y in ["weight", "bias"]: for y in ["weight", "bias"]:
k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y) k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y) k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
if k in sd: if k in sd:
sd[k_to] = sd.pop(k) sd[k_to] = sd.pop(k)
for y in ["weight", "bias"]: for y in ["weight", "bias"]:
k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y) k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
if k_from in sd: if k_from in sd:
weights = sd.pop(k_from) weights = sd.pop(k_from)
shape_from = weights.shape[0] // 3 shape_from = weights.shape[0] // 3
for x in range(3): for x in range(3):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"] p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y) k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd return sd
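# Note (sketch): with this change the "." separator is part of the prefix arguments
# instead of being hardcoded in the key templates, so call sites pass prefixes that
# end with a dot, e.g.:
#   transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)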
......
...@@ -48,7 +48,9 @@ class CLIPTextEncode: ...@@ -48,7 +48,9 @@ class CLIPTextEncode:
CATEGORY = "conditioning" CATEGORY = "conditioning"
def encode(self, clip, text): def encode(self, clip, text):
return ([[clip.encode(text), {}]], ) tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled}]], )
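# Note (sketch, assuming `clip` is an already-loaded CLIP object): each conditioning
# entry now carries the pooled text embedding alongside the token-level embeddings,
# so SDXL models can read it back, e.g.:
#   conditioning = CLIPTextEncode().encode(clip, "a photo of a cat")[0]
#   cond, opts = conditioning[0]
#   pooled = opts["pooled_output"]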
class ConditioningCombine: class ConditioningCombine:
@classmethod @classmethod
...@@ -1344,7 +1346,7 @@ NODE_CLASS_MAPPINGS = { ...@@ -1344,7 +1346,7 @@ NODE_CLASS_MAPPINGS = {
"DiffusersLoader": DiffusersLoader, "DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent, "LoadLatent": LoadLatent,
"SaveLatent": SaveLatent "SaveLatent": SaveLatent,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
......