"sims/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "2ddf5f12fee5620c827e95b089fb24df6946e349"
Commit 51581dbf authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix last commits causing an issue with the text encoder lora.

parent bf3f2717
...@@ -357,12 +357,13 @@ class ModelPatcher: ...@@ -357,12 +357,13 @@ class ModelPatcher:
self.patches += [(strength_patch, p, strength_model)] self.patches += [(strength_patch, p, strength_model)]
return p.keys() return p.keys()
def model_state_dict(self): def model_state_dict(self, filter_prefix=None):
sd = self.model.state_dict() sd = self.model.state_dict()
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: if filter_prefix is not None:
if not k.startswith("diffusion_model."): for k in keys:
sd.pop(k) if not k.startswith(filter_prefix):
sd.pop(k)
return sd return sd
def patch_model(self): def patch_model(self):
...@@ -443,7 +444,7 @@ class ModelPatcher: ...@@ -443,7 +444,7 @@ class ModelPatcher:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device) weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
return self.model return self.model
def unpatch_model(self): def unpatch_model(self):
model_sd = self.model.state_dict() model_sd = self.model_state_dict()
keys = list(self.backup.keys()) keys = list(self.backup.keys())
for k in keys: for k in keys:
model_sd[k][:] = self.backup[k] model_sd[k][:] = self.backup[k]
......
...@@ -14,7 +14,7 @@ class ModelMergeSimple: ...@@ -14,7 +14,7 @@ class ModelMergeSimple:
def merge(self, model1, model2, ratio): def merge(self, model1, model2, ratio):
m = model1.clone() m = model1.clone()
sd = model2.model_state_dict() sd = model2.model_state_dict("diffusion_model.")
for k in sd: for k in sd:
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
return (m, ) return (m, )
...@@ -35,7 +35,7 @@ class ModelMergeBlocks: ...@@ -35,7 +35,7 @@ class ModelMergeBlocks:
def merge(self, model1, model2, **kwargs): def merge(self, model1, model2, **kwargs):
m = model1.clone() m = model1.clone()
sd = model2.model_state_dict() sd = model2.model_state_dict("diffusion_model.")
default_ratio = next(iter(kwargs.values())) default_ratio = next(iter(kwargs.values()))
for k in sd: for k in sd:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment