Commit 2c7c14de authored by comfyanonymous's avatar comfyanonymous
Browse files

Support for SDXL text encoder lora.

parent fcef47f0
...@@ -223,13 +223,28 @@ def model_lora_keys(model, key_map={}): ...@@ -223,13 +223,28 @@ def model_lora_keys(model, key_map={}):
counter += 1 counter += 1
counter = 0 counter = 0
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
for b in range(24): clip_l_present = False
for b in range(32):
for c in LORA_CLIP_MAP: for c in LORA_CLIP_MAP:
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk: if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k key_map[lora_key] = k
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
clip_l_present = True
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
else:
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
key_map[lora_key] = k
#Locon stuff #Locon stuff
ds_counter = 0 ds_counter = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment