Commit ec120001 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add support for full diff lora keys.

parent 064d7583
...@@ -131,6 +131,18 @@ def load_lora(lora, to_load): ...@@ -131,6 +131,18 @@ def load_lora(lora, to_load):
loaded_keys.add(b_norm_name) loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,)
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
loaded_keys.add(diff_bias_name)
for x in lora.keys(): for x in lora.keys():
if x not in loaded_keys: if x not in loaded_keys:
print("lora key not loaded", x) print("lora key not loaded", x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment