Commit fa6dd7e5 authored by comfyanonymous

Fix lowvram issue with saving checkpoints.

The previous fix didn't cover the case where the model was loaded in
lowvram mode right before.
parent 49c20cdc
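Background for the fix (a summary, not part of the commit): in lowvram mode some weight patches are deferred; instead of writing the patched values into the weights, LowVramPatch callables are attached that compute them at cast time. Saving a checkpoint needs the real, fully patched weights, so the saving path requests force_patch_weights=True. The previous logic skipped models already present in current_loaded_models, so a model loaded in lowvram mode right before a save kept its deferred patches and the saved file contained unpatched weights. A minimal sketch of the new check, using a hypothetical stand-in class (FakePatcher is not ComfyUI code):

# Minimal sketch (stand-in class, not ComfyUI's real one) of the case this
# commit handles: a cached model still carrying deferred lowvram patches
# must be reloaded when fully patched weights are required.

class FakePatcher:
    def __init__(self, lowvram_patch_counter):
        # Counts weights whose patches were deferred via LowVramPatch
        # instead of being baked into the stored weights.
        self.lowvram_patch_counter = lowvram_patch_counter

def should_reload_model(model, force_patch_weights=False):
    # Mirrors LoadedModel.should_reload_model from the first hunk below.
    return force_patch_weights and model.lowvram_patch_counter > 0

lowvram_model = FakePatcher(lowvram_patch_counter=12)
normal_model = FakePatcher(lowvram_patch_counter=0)

assert should_reload_model(lowvram_model, force_patch_weights=True)       # reload before saving
assert not should_reload_model(lowvram_model, force_patch_weights=False)  # fine for sampling
assert not should_reload_model(normal_model, force_patch_weights=True)    # already fully patched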
comfy/model_management.py:

@@ -309,6 +309,11 @@ class LoadedModel:
         self.weights_loaded = True
         return self.real_model
 
+    def should_reload_model(self, force_patch_weights=False):
+        if force_patch_weights and self.model.lowvram_patch_counter > 0:
+            return True
+        return False
+
     def model_unload(self, unpatch_weights=True):
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
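For reference, the objects counted by lowvram_patch_counter are the LowVramPatch callables visible in the model_patcher.py hunks further down. A condensed sketch of what one does, reconstructed from the diff and slightly simplified:

# Condensed sketch of the deferred-patch callable being counted
# (mirrors the LowVramPatch usage shown in the later hunks).
class LowVramPatch:
    def __init__(self, key, model_patcher):
        self.key = key
        self.model_patcher = model_patcher

    def __call__(self, weight):
        # Apply the stored patches for this key to the weight at cast
        # time instead of modifying the stored weight itself, which is
        # why the stored weights are not safe to save as-is.
        patches = self.model_patcher.patches[self.key]
        return self.model_patcher.calculate_weight(patches, weight, self.key)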
@@ -391,10 +396,22 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
     models_already_loaded = []
     for x in models:
         loaded_model = LoadedModel(x)
+        loaded = None
 
-        if loaded_model in current_loaded_models:
-            models_already_loaded.append(loaded_model)
-        else:
+        try:
+            loaded_model_index = current_loaded_models.index(loaded_model)
+        except:
+            loaded_model_index = None
+
+        if loaded_model_index is not None:
+            loaded = current_loaded_models[loaded_model_index]
+            if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
+                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
+                loaded = None
+            else:
+                models_already_loaded.append(loaded)
+
+        if loaded is None:
             if hasattr(x, "model"):
                 logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
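A note on the lookup above (an inference from the code, not stated by the commit): current_loaded_models.index(loaded_model) retrieves the cached LoadedModel that wraps the same underlying model, and it is that cached instance, not the freshly built wrapper, whose state reflects how the model was actually loaded. The sketch below illustrates the pattern with a hypothetical Wrapper class; it assumes equality compares the wrapped model, which is what list.index relies on here:

# Sketch of the cached-instance lookup pattern, with a hypothetical
# Wrapper class standing in for LoadedModel.
class Wrapper:
    def __init__(self, model):
        self.model = model
        self.lowvram_state = None       # only the cached copy carries real state

    def __eq__(self, other):
        return self.model is other.model

model = object()
cache = [Wrapper(model)]
cache[0].lowvram_state = "loaded_lowvram"

probe = Wrapper(model)                  # freshly built wrapper, no state
try:
    idx = cache.index(probe)            # matches via __eq__ on the same model
except ValueError:
    idx = None

if idx is not None:
    cached = cache[idx]                 # consult the cached wrapper, not the probe
    assert cached.lowvram_state == "loaded_lowvram"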
comfy/model_patcher.py:

@@ -58,6 +58,7 @@ class ModelPatcher:
         self.weight_inplace_update = weight_inplace_update
         self.model_lowvram = False
+        self.lowvram_patch_counter = 0
         self.patches_uuid = uuid.uuid4()
 
     def model_size(self):
@@ -284,6 +285,7 @@ class ModelPatcher:
                 return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
 
         mem_counter = 0
+        patch_counter = 0
         for n, m in self.model.named_modules():
             lowvram_weight = False
             if hasattr(m, "comfy_cast_weights"):
@@ -300,11 +302,13 @@ class ModelPatcher:
                         self.patch_weight_to_device(weight_key)
                     else:
                         m.weight_function = LowVramPatch(weight_key, self)
+                        patch_counter += 1
                 if bias_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(bias_key)
                     else:
                         m.bias_function = LowVramPatch(bias_key, self)
+                        patch_counter += 1
 
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
@@ -317,6 +321,7 @@ class ModelPatcher:
                 logging.debug("lowvram: loaded module regularly {}".format(m))
 
         self.model_lowvram = True
+        self.lowvram_patch_counter = patch_counter
         return self.model
 
     def calculate_weight(self, patches, weight, key):
@@ -468,6 +473,7 @@ class ModelPatcher:
                 m.bias_function = None
 
             self.model_lowvram = False
+            self.lowvram_patch_counter = 0
 
         keys = list(self.backup.keys())
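Putting the hunks together, the counter has a simple lifecycle: patch_model_lowvram writes it once per load (zero when force_patch_weights applied every patch eagerly) and unpatch_model resets it, so should_reload_model is cheap and always reflects the most recent load. A simplified, self-contained sketch of that lifecycle (MiniPatcher is a stand-in, not the real ModelPatcher):

# Simplified lifecycle of lowvram_patch_counter (stand-in code, not the
# real ModelPatcher): written once per lowvram load, zeroed on unpatch.
class MiniPatcher:
    def __init__(self):
        self.patches = {"layer.weight": "pending"}  # pretend one pending patch
        self.lowvram_patch_counter = 0

    def patch_model_lowvram(self, force_patch_weights=False):
        patch_counter = 0
        for key in self.patches:
            if force_patch_weights:
                pass                                # patch baked into the weight itself
            else:
                patch_counter += 1                  # patch deferred to cast time
        self.lowvram_patch_counter = patch_counter

    def unpatch_model(self):
        self.lowvram_patch_counter = 0              # hooks removed, weights restored

p = MiniPatcher()
p.patch_model_lowvram()                             # ordinary lowvram load
assert p.lowvram_patch_counter == 1                 # -> should_reload_model fires on save
p.unpatch_model()
p.patch_model_lowvram(force_patch_weights=True)     # reload path taken for saving
assert p.lowvram_patch_counter == 0                 # weights now fully patched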