Commit 20f579d9 authored by comfyanonymous

Add DualClipLoader to load clip models for SDXL.

Update LoadClip to load clip models for SDXL refiner.
parent b7933960
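
In effect, comfy.sd.load_clip now takes a list of checkpoint paths instead of a single path and picks the text-encoder class from the state-dict keys. A minimal usage sketch, not part of the diff (the checkpoint file names and embedding directory below are placeholders):

import comfy.sd

# One file: SD1, SD2, or the SDXL refiner text encoder, detected from its keys.
refiner_clip = comfy.sd.load_clip(ckpt_paths=["clip_g.safetensors"],
                                  embedding_directory="models/embeddings")

# Two files: loaded into the combined SDXL CLIP-L + CLIP-G model.
base_clip = comfy.sd.load_clip(ckpt_paths=["clip_l.safetensors", "clip_g.safetensors"],
                               embedding_directory="models/embeddings")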
comfy/sd.py
@@ -19,6 +19,7 @@ from . import model_detection
 from . import sd1_clip
 from . import sd2_clip
+from . import sdxl_clip

 def load_model_weights(model, sd):
     m, u = model.load_state_dict(sd, strict=False)
@@ -524,7 +525,7 @@ class CLIP:
         return n

     def load_from_state_dict(self, sd):
-        self.cond_stage_model.transformer.load_state_dict(sd, strict=False)
+        self.cond_stage_model.load_sd(sd)

     def add_patches(self, patches, strength=1.0):
         return self.patcher.add_patches(patches, strength)
@@ -555,6 +556,8 @@ class CLIP:
         tokens = self.tokenize(text)
         return self.encode_from_tokens(tokens)

+    def load_sd(self, sd):
+        return self.cond_stage_model.load_sd(sd)
+
 class VAE:
     def __init__(self, ckpt_path=None, device=None, config=None):
@@ -959,22 +962,42 @@ def load_style_model(ckpt_path):
     return StyleModel(model)

-def load_clip(ckpt_path, embedding_directory=None):
-    clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
+def load_clip(ckpt_paths, embedding_directory=None):
+    clip_data = []
+    for p in ckpt_paths:
+        clip_data.append(utils.load_torch_file(p, safe_load=True))

     class EmptyClass:
         pass

+    for i in range(len(clip_data)):
+        if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
+            clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32)
+
     clip_target = EmptyClass()
     clip_target.params = {}
-    if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
-        clip_target.clip = sd2_clip.SD2ClipModel
-        clip_target.tokenizer = sd2_clip.SD2Tokenizer
+    if len(clip_data) == 1:
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+        elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sd2_clip.SD2ClipModel
+            clip_target.tokenizer = sd2_clip.SD2Tokenizer
+        else:
+            clip_target.clip = sd1_clip.SD1ClipModel
+            clip_target.tokenizer = sd1_clip.SD1Tokenizer
     else:
-        clip_target.clip = sd1_clip.SD1ClipModel
-        clip_target.tokenizer = sd1_clip.SD1Tokenizer
+        clip_target.clip = sdxl_clip.SDXLClipModel
+        clip_target.tokenizer = sdxl_clip.SDXLTokenizer

     clip = CLIP(clip_target, embedding_directory=embedding_directory)
-    clip.load_from_state_dict(clip_data)
+    for c in clip_data:
+        m, u = clip.load_sd(c)
+        if len(m) > 0:
+            print("clip missing:", m)
+
+        if len(u) > 0:
+            print("clip unexpected:", u)
     return clip

 def load_gligen(ckpt_path):
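
Note on the key checks in load_clip above: the three text encoders differ in depth (the CLIP ViT-L encoder used by SD1 has 12 transformer layers, the OpenCLIP-H encoder used by SD2 has 24, and the OpenCLIP-bigG encoder used by SDXL and its refiner has 32), so the presence of layer 30 or layer 22 identifies the checkpoint type. A hypothetical stand-alone helper illustrating the same idea (not part of the commit):

def guess_text_encoder(sd):
    # Assumes keys already follow the "text_model." layout, i.e. after
    # utils.transformers_convert has been applied where needed.
    if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
        return "clip_g"  # 32-layer bigG encoder -> SDXL / SDXL refiner
    if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
        return "clip_h"  # 24-layer encoder -> SD2
    return "clip_l"      # 12-layer ViT-L encoder -> SD1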
comfy/sd1_clip.py
@@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def encode(self, tokens):
         return self(tokens)

+    def load_sd(self, sd):
+        return self.transformer.load_state_dict(sd, strict=False)
+
 def parse_parentheses(string):
     result = []
     current_item = ""
comfy/sdxl_clip.py
@@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
         self.layer = "hidden"
         self.layer_idx = layer_idx

+    def load_sd(self, sd):
+        if "text_projection" in sd:
+            self.text_projection[:] = sd.pop("text_projection")
+        return super().load_sd(sd)
+
 class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
         super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)
@@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module):
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         return torch.cat([l_out, g_out], dim=-1), g_pooled

+    def load_sd(self, sd):
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+            return self.clip_g.load_sd(sd)
+        else:
+            return self.clip_l.load_sd(sd)
+
 class SDXLRefinerClipModel(torch.nn.Module):
     def __init__(self, device="cpu"):
         super().__init__()
@@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module):
         g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
         return g_out, g_pooled
+
+    def load_sd(self, sd):
+        return self.clip_g.load_sd(sd)
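
Because SDXLClipModel.load_sd routes each incoming state dict to clip_g or clip_l based on the same layers.30 key, the two checkpoints passed to load_clip (and therefore to DualCLIPLoader) can be given in either order. Illustrative sketch with placeholder file names:

import comfy.sd

a = comfy.sd.load_clip(ckpt_paths=["clip_g.safetensors", "clip_l.safetensors"])
b = comfy.sd.load_clip(ckpt_paths=["clip_l.safetensors", "clip_g.safetensors"])
# a and b end up with the same weights loaded into clip_l and clip_g.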
nodes.py
@@ -520,11 +520,27 @@ class CLIPLoader:
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

-    CATEGORY = "loaders"
+    CATEGORY = "advanced/loaders"

     def load_clip(self, clip_name):
         clip_path = folder_paths.get_full_path("clip", clip_name)
-        clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return (clip,)
+
+class DualCLIPLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ),
+                             }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "load_clip"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_clip(self, clip_name1, clip_name2):
+        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
+        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return (clip,)

 class CLIPVisionLoader:
@@ -1315,6 +1331,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "DualCLIPLoader": DualCLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,
     "unCLIPConditioning": unCLIPConditioning,