Commit 69c8d6d8 authored by comfyanonymous

Single and dual clip loader nodes support SD3.

For SD3, you can use the CLIPLoader to load only the t5xxl text encoder, or the
DualCLIPLoader to load only CLIP-L and CLIP-G.
parent 0e49211a
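As a rough usage sketch, both new node paths reduce to calling comfy.sd.load_clip with the new CLIPType.SD3 value; the filenames below are hypothetical placeholders, not files referenced by this commit:

import comfy.sd
import folder_paths

# CLIPLoader path: t5xxl only for SD3 (hypothetical filename).
t5_path = folder_paths.get_full_path("clip", "t5xxl.safetensors")
clip = comfy.sd.load_clip(ckpt_paths=[t5_path],
                          embedding_directory=folder_paths.get_folder_paths("embeddings"),
                          clip_type=comfy.sd.CLIPType.SD3)

# DualCLIPLoader path: CLIP-L and CLIP-G only for SD3 (hypothetical filenames).
clip_l_path = folder_paths.get_full_path("clip", "clip_l.safetensors")
clip_g_path = folder_paths.get_full_path("clip", "clip_g.safetensors")
clip = comfy.sd.load_clip(ckpt_paths=[clip_l_path, clip_g_path],
                          embedding_directory=folder_paths.get_folder_paths("embeddings"),
                          clip_type=comfy.sd.CLIPType.SD3)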
@@ -370,6 +370,7 @@ def load_style_model(ckpt_path):

class CLIPType(Enum):
    STABLE_DIFFUSION = 1
    STABLE_CASCADE = 2
    SD3 = 3

def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
    clip_data = []
@@ -399,12 +400,20 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
        elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
            clip_target.clip = sd2_clip.SD2ClipModel
            clip_target.tokenizer = sd2_clip.SD2Tokenizer
        elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
            dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
            clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
            clip_target.tokenizer = sd3_clip.SD3Tokenizer
        else:
            clip_target.clip = sd1_clip.SD1ClipModel
            clip_target.tokenizer = sd1_clip.SD1Tokenizer
    elif len(clip_data) == 2:
        clip_target.clip = sdxl_clip.SDXLClipModel
        clip_target.tokenizer = sdxl_clip.SDXLTokenizer
        if clip_type == CLIPType.SD3:
            clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
            clip_target.tokenizer = sd3_clip.SD3Tokenizer
        else:
            clip_target.clip = sdxl_clip.SDXLClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
    elif len(clip_data) == 3:
        clip_target.clip = sd3_clip.SD3ClipModel
        clip_target.tokenizer = sd3_clip.SD3Tokenizer
@@ -142,3 +142,9 @@ class SD3ClipModel(torch.nn.Module):
            return self.clip_l.load_sd(sd)
        else:
            return self.t5xxl.load_sd(sd)

def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
    class SD3ClipModel_(SD3ClipModel):
        def __init__(self, device="cpu", dtype=None):
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
    return SD3ClipModel_
@@ -522,11 +522,7 @@ class SD3(supported_models_base.BASE):
            t5 = True
            dtype_t5 = state_dict[t5_key].dtype

        class SD3ClipModel(sd3_clip.SD3ClipModel):
            def __init__(self, device="cpu", dtype=None):
                super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))

models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
@@ -818,7 +818,7 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
                              "type": (["stable_diffusion", "stable_cascade"], ),
                              "type": (["stable_diffusion", "stable_cascade", "sd3"], ),
                              }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"
@@ -829,6 +829,8 @@ class CLIPLoader:
        clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
        if type == "stable_cascade":
            clip_type = comfy.sd.CLIPType.STABLE_CASCADE
        elif type == "sd3":
            clip_type = comfy.sd.CLIPType.SD3

        clip_path = folder_paths.get_full_path("clip", clip_name)
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
@@ -837,17 +839,24 @@ class CLIPLoader:

class DualCLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ),
        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
                              "clip_name2": (folder_paths.get_filename_list("clip"), ),
                              "type": (["sdxl", "sd3"], ),
                              }}
    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"

    CATEGORY = "advanced/loaders"

    def load_clip(self, clip_name1, clip_name2):
    def load_clip(self, clip_name1, clip_name2, type):
        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
        if type == "sdxl":
            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
        elif type == "sd3":
            clip_type = comfy.sd.CLIPType.SD3
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
        return (clip,)

class CLIPVisionLoader: