Commit 45beebd3 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add a type of model patch useful for model merging.

parent 186f9204
...@@ -347,15 +347,23 @@ class ModelPatcher: ...@@ -347,15 +347,23 @@ class ModelPatcher:
def model_dtype(self): def model_dtype(self):
return self.model.get_dtype() return self.model.get_dtype()
def add_patches(self, patches, strength=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = {} p = {}
model_sd = self.model.state_dict() model_sd = self.model.state_dict()
for k in patches: for k in patches:
if k in model_sd: if k in model_sd:
p[k] = patches[k] p[k] = patches[k]
self.patches += [(strength, p)] self.patches += [(strength_patch, p, strength_model)]
return p.keys() return p.keys()
def model_state_dict(self):
sd = self.model.state_dict()
keys = list(sd.keys())
for k in keys:
if not k.startswith("diffusion_model."):
sd.pop(k)
return sd
def patch_model(self): def patch_model(self):
model_sd = self.model.state_dict() model_sd = self.model.state_dict()
for p in self.patches: for p in self.patches:
...@@ -371,8 +379,14 @@ class ModelPatcher: ...@@ -371,8 +379,14 @@ class ModelPatcher:
self.backup[key] = weight.clone() self.backup[key] = weight.clone()
alpha = p[0] alpha = p[0]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if len(v) == 4: #lora/locon if len(v) == 1:
weight += alpha * (v[0]).type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0] mat1 = v[0]
mat2 = v[1] mat2 = v[1]
if v[2] is not None: if v[2] is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment