"doc/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "5a9c8e49ac166262b5e2ea1196c90ab54b7720da"
Commit 1cb3f6a8 authored by comfyanonymous

Move text projection into the CLIP model code.

Fix issue with not loading the SSD1B clip correctly.
parent 6533b172
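For context: the text projection is a bias-free linear map applied to the pooled (EOS token) output of the CLIP text encoder. Checkpoints that use the CLIP-G text encoder (SDXL, SSD-1B, Stable Cascade) carry a real projection matrix for that pooled output, which is the weight this commit re-routes into the model itself. Initializing the new layer to the identity keeps the old behaviour whenever a checkpoint provides no projection. A minimal sketch of that idea (dimensions and names here are illustrative, not ComfyUI API):

```python
import torch

embed_dim = 1280  # illustrative; CLIP-G uses 1280, CLIP-L uses 768

# Bias-free projection initialized to the identity: until real checkpoint
# weights are copied in, the pooled output passes through unchanged.
text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
with torch.no_grad():
    text_projection.weight.copy_(torch.eye(embed_dim))

pooled = torch.randn(1, embed_dim)  # stand-in for the encoder's pooled output
assert torch.allclose(text_projection(pooled), pooled)
```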
@@ -119,6 +119,9 @@ class CLIPTextModel(torch.nn.Module):
         super().__init__()
         self.num_layers = config_dict["num_hidden_layers"]
         self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
+        embed_dim = config_dict["hidden_size"]
+        self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
+        self.text_projection.weight.copy_(torch.eye(embed_dim))
         self.dtype = dtype

     def get_input_embeddings(self):
@@ -128,7 +131,10 @@ class CLIPTextModel(torch.nn.Module):
         self.text_model.embeddings.token_embedding = embeddings

     def forward(self, *args, **kwargs):
-        return self.text_model(*args, **kwargs)
+        x = self.text_model(*args, **kwargs)
+        out = self.text_projection(x[2])
+        return (x[0], x[1], out)


 class CLIPVisionEmbeddings(torch.nn.Module):
     def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
...
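The reworked forward() above only touches the third element of the inner text model's output tuple (the pooled output); the hidden states in x[0] and x[1] pass through untouched. A toy version of the same wrapper pattern, using a stand-in inner module rather than the real CLIPTextModel_:

```python
import torch

class InnerStub(torch.nn.Module):
    """Stand-in that mimics returning (hidden, intermediate, pooled)."""
    def forward(self, tokens):
        hidden = torch.randn(1, tokens.shape[1], 8)
        return hidden, hidden.clone(), hidden[:, -1]

class ProjectedWrapper(torch.nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.text_model = InnerStub()
        self.text_projection = torch.nn.Linear(dim, dim, bias=False)
        with torch.no_grad():
            self.text_projection.weight.copy_(torch.eye(dim))

    def forward(self, *args, **kwargs):
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])  # project only the pooled output
        return (x[0], x[1], out)

hidden, intermediate, pooled = ProjectedWrapper()(torch.zeros(1, 4, dtype=torch.long))
```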
@@ -52,7 +52,7 @@ def load_clip_weights(model, sd):
     if ids.dtype == torch.float32:
         sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

-    sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
+    sd = comfy.utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.")
     return load_model_weights(model, sd)
@@ -361,7 +361,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
     for i in range(len(clip_data)):
         if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
-            clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32)
+            clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "")

     clip_target = EmptyClass()
     clip_target.params = {}
...
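The resblocks check in the second hunk is format detection: OpenCLIP-style text encoders name their blocks transformer.resblocks.N..., while transformers-style state dicts use text_model.encoder.layers.N..., so only the former need converting. A small illustration with made-up state dicts and a helper written here purely for illustration:

```python
import torch

def looks_like_openclip(sd):
    # Same heuristic as the hunk above: OpenCLIP text towers expose this key.
    return "transformer.resblocks.0.ln_1.weight" in sd

openclip_style = {"transformer.resblocks.0.ln_1.weight": torch.zeros(1280)}
transformers_style = {"text_model.encoder.layers.0.layer_norm1.weight": torch.zeros(1280)}

print(looks_like_openclip(openclip_style))      # True  -> run clip_text_transformers_convert
print(looks_like_openclip(transformers_style))  # False -> already in the expected layout
```

One side effect of routing the SD2.x path through the shared helper is that 32 layers are now passed to transformers_convert instead of 24; since that conversion only renames keys it actually finds in the state dict, the larger layer count should be harmless for the 24-layer CLIP-H model.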
@@ -86,7 +86,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer = layer
         self.layer_idx = None
         self.special_tokens = special_tokens
-        self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = enable_attention_masks
@@ -182,18 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         else:
             pooled_output = None

-        if self.text_projection is not None and pooled_output is not None:
-            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
         return z.float(), pooled_output

     def encode(self, tokens):
         return self(tokens)

     def load_sd(self, sd):
-        if "text_projection" in sd:
-            self.text_projection[:] = sd.pop("text_projection")
-        if "text_projection.weight" in sd:
-            self.text_projection[:] = sd.pop("text_projection.weight").transpose(0, 1)
         return self.transformer.load_state_dict(sd, strict=False)


 def parse_parentheses(string):
...
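Since the projection now lives on the transformer module itself, load_sd no longer needs the special-cased key handling: a plain load_state_dict(..., strict=False) picks up text_projection.weight when a checkpoint provides it and leaves the identity weights alone when it doesn't. A small stand-alone illustration of that behaviour (TinyTextModel is a made-up stand-in, not a ComfyUI class):

```python
import torch

class TinyTextModel(torch.nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.text_projection = torch.nn.Linear(dim, dim, bias=False)
        with torch.no_grad():
            self.text_projection.weight.copy_(torch.eye(dim))

model = TinyTextModel()

# Checkpoint without a projection: strict=False just reports the key as missing
# and the identity initialization stays in effect.
missing, unexpected = model.load_state_dict({}, strict=False)
print(missing)  # ['text_projection.weight']

# Checkpoint with a projection: loaded like any other parameter.
model.load_state_dict({"text_projection.weight": torch.full((4, 4), 0.5)}, strict=False)
```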
@@ -75,7 +75,7 @@ class SD20(supported_models_base.BASE):
         replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
         replace_prefix["cond_stage_model.model."] = "clip_h."
         state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
-        state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
+        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
         return state_dict

     def process_clip_state_dict_for_saving(self, state_dict):
@@ -134,7 +134,7 @@ class SDXLRefiner(supported_models_base.BASE):
         replace_prefix["conditioner.embedders.0.model."] = "clip_g."
         state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
-        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
+        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
         state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
@@ -182,10 +182,8 @@ class SDXL(supported_models_base.BASE):
         replace_prefix["conditioner.embedders.1.model."] = "clip_g."
         state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
-        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
-        keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
         state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
+        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
         return state_dict

     def process_clip_state_dict_for_saving(self, state_dict):
@@ -338,6 +336,12 @@ class Stable_Cascade_C(supported_models_base.BASE):
                 state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
         return state_dict

+    def process_clip_state_dict(self, state_dict):
+        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
+        if "clip_g.text_projection" in state_dict:
+            state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
+        return state_dict
+
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.StableCascade_C(self, device=device)
         return out
...
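The transpose in the new Stable Cascade path (and in the helper added in utils below) reconciles two storage conventions: the old code kept the projection as a plain matrix P and computed pooled @ P, while torch.nn.Linear stores a weight W and computes pooled @ W.T, so the matrix has to be transposed when it becomes the Linear's weight. A quick numeric check of that equivalence (dimensions are arbitrary):

```python
import torch

dim = 8
pooled = torch.randn(2, dim)
P = torch.randn(dim, dim)                   # old-style "text_projection" matrix

linear = torch.nn.Linear(dim, dim, bias=False)
with torch.no_grad():
    linear.weight.copy_(P.transpose(0, 1))  # same transpose as in the conversion above

assert torch.allclose(pooled @ P, linear(pooled), atol=1e-5)
```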
@@ -98,8 +98,22 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
                     p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                     k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                     sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+    return sd
+
+def clip_text_transformers_convert(sd, prefix_from, prefix_to):
+    sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)
+
+    tp = "{}text_projection.weight".format(prefix_from)
+    if tp in sd:
+        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)
+
+    tp = "{}text_projection".format(prefix_from)
+    if tp in sd:
+        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1)
     return sd

 UNET_MAP_ATTENTIONS = {
     "proj_in.weight",
     "proj_in.bias",
...
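To make the key handling of the new helper concrete, here is a stripped-down sketch of just the projection-renaming part with a toy state dict; the real clip_text_transformers_convert additionally runs transformers_convert on the attention/MLP keys first:

```python
import torch

def rename_projection(sd, prefix_from, prefix_to):
    # Illustrative re-implementation of the projection handling added above.
    tp = "{}text_projection.weight".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)
    tp = "{}text_projection".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1)
    return sd

# Some checkpoints store a bare "text_projection" matrix rather than "text_projection.weight":
sd = {"clip_g.text_projection": torch.randn(1280, 1280)}
sd = rename_projection(sd, "clip_g.", "clip_g.transformer.")
print(list(sd))  # ['clip_g.transformer.text_projection.weight']
```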