Commit efa5a711 authored by comfyanonymous's avatar comfyanonymous
Browse files

Reduce memory usage when applying DORA: #3557

parent 58c98382
...@@ -9,7 +9,7 @@ import comfy.model_management ...@@ -9,7 +9,7 @@ import comfy.model_management
from comfy.types import UnetWrapperFunction from comfy.types import UnetWrapperFunction
def apply_weight_decompose(dora_scale, weight): def weight_decompose_scale(dora_scale, weight):
weight_norm = ( weight_norm = (
weight.transpose(0, 1) weight.transpose(0, 1)
.reshape(weight.shape[1], -1) .reshape(weight.shape[1], -1)
...@@ -18,7 +18,7 @@ def apply_weight_decompose(dora_scale, weight): ...@@ -18,7 +18,7 @@ def apply_weight_decompose(dora_scale, weight):
.transpose(0, 1) .transpose(0, 1)
) )
return weight * (dora_scale / weight_norm).type(weight.dtype) return (dora_scale / weight_norm).type(weight.dtype)
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy() to = model_options["transformer_options"].copy()
...@@ -365,7 +365,7 @@ class ModelPatcher: ...@@ -365,7 +365,7 @@ class ModelPatcher:
try: try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None: if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e: except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e)) logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr": elif patch_type == "lokr":
...@@ -407,7 +407,7 @@ class ModelPatcher: ...@@ -407,7 +407,7 @@ class ModelPatcher:
try: try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None: if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e: except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e)) logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha": elif patch_type == "loha":
...@@ -439,7 +439,7 @@ class ModelPatcher: ...@@ -439,7 +439,7 @@ class ModelPatcher:
try: try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None: if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e: except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e)) logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora": elif patch_type == "glora":
...@@ -456,7 +456,7 @@ class ModelPatcher: ...@@ -456,7 +456,7 @@ class ModelPatcher:
try: try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None: if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e: except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e)) logging.error("ERROR {} {} {}".format(patch_type, key, e))
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment