"vscode:/vscode.git/clone" did not exist on "3b5badb770ad2b91cac4e046e34adf163e8cbf21"
Commit 97d03ae0 authored by comfyanonymous

StableCascade CLIP model support.

parent 667c9281
comfy/sd.py

 import torch
+from enum import Enum
 from comfy import model_management
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine

@@ -309,8 +310,11 @@ def load_style_model(ckpt_path):
     model.load_state_dict(model_data)
     return StyleModel(model)

+class CLIPType(Enum):
+    STABLE_DIFFUSION = 1
+    STABLE_CASCADE = 2

-def load_clip(ckpt_paths, embedding_directory=None):
+def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
     clip_data = []
     for p in ckpt_paths:
         clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))

@@ -326,8 +330,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
     clip_target.params = {}
     if len(clip_data) == 1:
         if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
-            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
-            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+            if clip_type == CLIPType.STABLE_CASCADE:
+                clip_target.clip = sdxl_clip.StableCascadeClipModel
+                clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
+            else:
+                clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+                clip_target.tokenizer = sdxl_clip.SDXLTokenizer
         elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
             clip_target.clip = sd2_clip.SD2ClipModel
             clip_target.tokenizer = sd2_clip.SD2Tokenizer
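The new clip_type argument defaults to CLIPType.STABLE_DIFFUSION, so existing callers are unaffected; only callers that pass CLIPType.STABLE_CASCADE get the new CLIP-G wiring. A minimal call-site sketch, assuming a ComfyUI checkout where comfy.sd is importable and "clip_g.safetensors" stands in for a real checkpoint path:

import comfy.sd

# Default: behaves exactly as before this commit.
clip = comfy.sd.load_clip(ckpt_paths=["clip_g.safetensors"])

# Opt in to the Stable Cascade text encoder.
clip = comfy.sd.load_clip(ckpt_paths=["clip_g.safetensors"],
                          clip_type=comfy.sd.CLIPType.STABLE_CASCADE)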
comfy/sd1_clip.py

@@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True):  # clip-vit-base-patch32
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS

@@ -88,7 +88,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.special_tokens = special_tokens
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
-        self.enable_attention_masks = False
+        self.enable_attention_masks = enable_attention_masks
         self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
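Threading enable_attention_masks through the constructor lets subclasses request that padding tokens be masked out of attention, instead of the attribute being hard-coded to False. As a hedged illustration of the idea (not the repository's exact masking code), a mask can be derived from the token ids using the pad id from special_tokens above; note that "end" and "pad" share id 49407 here, so this naive version would also hide the end-of-text token, which a real implementation must keep visible:

import torch

pad_token = 49407  # the "pad" entry from special_tokens above
tokens = torch.tensor([[49406, 320, 1929, 49407, 49407, 49407]])
# 1 = attend, 0 = ignore; naive sketch that also masks the end token.
attention_mask = (tokens != pad_token).long()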
comfy/sdxl_clip.py

@@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
 class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None):
         super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
+
+
+class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, tokenizer_path=None, embedding_directory=None):
+        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+
+class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None):
+        super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
+
+class StableCascadeClipG(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
+                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
+
+    def load_sd(self, sd):
+        return super().load_sd(sd)
+
+class StableCascadeClipModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None):
+        super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
comfy/supported_models.py

@@ -336,7 +336,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
         return out

     def clip_target(self):
-        return None
+        return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)

 class Stable_Cascade_B(Stable_Cascade_C):
     unet_config = {
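Returning a ClipTarget instead of None is what allows checkpoint loading to construct a text encoder for Stable Cascade at all: ClipTarget simply pairs a tokenizer class with a CLIP model class for the loader to instantiate. The equivalent standalone wiring, sketched under the assumption that both modules are importable:

from comfy import supported_models_base, sdxl_clip

target = supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer,
                                          sdxl_clip.StableCascadeClipModel)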
nodes.py

@@ -854,15 +854,20 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
+                              "type": (["stable_diffusion", "stable_cascade"], ),
                             }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

     CATEGORY = "advanced/loaders"

-    def load_clip(self, clip_name):
+    def load_clip(self, clip_name, type="stable_diffusion"):
+        clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
+        if type == "stable_cascade":
+            clip_type = comfy.sd.CLIPType.STABLE_CASCADE
+
         clip_path = folder_paths.get_full_path("clip", clip_name)
-        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
         return (clip,)

 class DualCLIPLoader:
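On the node side, the new "type" combo is mapped onto the CLIPType enum before delegating to comfy.sd.load_clip, so existing workflows that never set it keep the default behaviour. A rough sketch of driving the node directly, assuming a running ComfyUI environment and a hypothetical "stable_cascade_clip_g.safetensors" file in the configured clip folder:

loader = CLIPLoader()
(clip,) = loader.load_clip("stable_cascade_clip_g.safetensors", type="stable_cascade")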