Commit fa6dd7e5 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix a lowvram issue when saving checkpoints.

The previous fix didn't cover the case where the model had been loaded
in lowvram mode immediately beforehand.
parent 49c20cdc
...@@ -309,6 +309,11 @@ class LoadedModel: ...@@ -309,6 +309,11 @@ class LoadedModel:
self.weights_loaded = True self.weights_loaded = True
return self.real_model return self.real_model
def should_reload_model(self, force_patch_weights=False):
    """Return True if the model must be unloaded and reloaded.

    A reload is required when the caller insists that all patches be baked
    into the weights (``force_patch_weights=True``) but the model was last
    loaded in lowvram mode with one or more patches applied lazily at
    cast time (``lowvram_patch_counter > 0``) — those lazy patches would
    otherwise be missing from a saved checkpoint.

    Args:
        force_patch_weights: when True, require fully patched weights.

    Returns:
        bool: True when a reload with full weight patching is needed.
    """
    # bool(...) keeps the exact True/False contract of the original
    # if/return-True/return-False form.
    return bool(force_patch_weights and self.model.lowvram_patch_counter > 0)
def model_unload(self, unpatch_weights=True): def model_unload(self, unpatch_weights=True):
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device) self.model.model_patches_to(self.model.offload_device)
...@@ -391,10 +396,22 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False): ...@@ -391,10 +396,22 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
models_already_loaded = [] models_already_loaded = []
for x in models: for x in models:
loaded_model = LoadedModel(x) loaded_model = LoadedModel(x)
loaded = None
if loaded_model in current_loaded_models: try:
models_already_loaded.append(loaded_model) loaded_model_index = current_loaded_models.index(loaded_model)
else: except:
loaded_model_index = None
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
models_already_loaded.append(loaded)
if loaded is None:
if hasattr(x, "model"): if hasattr(x, "model"):
logging.info(f"Requested to load {x.model.__class__.__name__}") logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model) models_to_load.append(loaded_model)
......
...@@ -58,6 +58,7 @@ class ModelPatcher: ...@@ -58,6 +58,7 @@ class ModelPatcher:
self.weight_inplace_update = weight_inplace_update self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False self.model_lowvram = False
self.lowvram_patch_counter = 0
self.patches_uuid = uuid.uuid4() self.patches_uuid = uuid.uuid4()
def model_size(self): def model_size(self):
...@@ -284,6 +285,7 @@ class ModelPatcher: ...@@ -284,6 +285,7 @@ class ModelPatcher:
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0 mem_counter = 0
patch_counter = 0
for n, m in self.model.named_modules(): for n, m in self.model.named_modules():
lowvram_weight = False lowvram_weight = False
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
...@@ -300,11 +302,13 @@ class ModelPatcher: ...@@ -300,11 +302,13 @@ class ModelPatcher:
self.patch_weight_to_device(weight_key) self.patch_weight_to_device(weight_key)
else: else:
m.weight_function = LowVramPatch(weight_key, self) m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
if force_patch_weights: if force_patch_weights:
self.patch_weight_to_device(bias_key) self.patch_weight_to_device(bias_key)
else: else:
m.bias_function = LowVramPatch(bias_key, self) m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True m.comfy_cast_weights = True
...@@ -317,6 +321,7 @@ class ModelPatcher: ...@@ -317,6 +321,7 @@ class ModelPatcher:
logging.debug("lowvram: loaded module regularly {}".format(m)) logging.debug("lowvram: loaded module regularly {}".format(m))
self.model_lowvram = True self.model_lowvram = True
self.lowvram_patch_counter = patch_counter
return self.model return self.model
def calculate_weight(self, patches, weight, key): def calculate_weight(self, patches, weight, key):
...@@ -468,6 +473,7 @@ class ModelPatcher: ...@@ -468,6 +473,7 @@ class ModelPatcher:
m.bias_function = None m.bias_function = None
self.model_lowvram = False self.model_lowvram = False
self.lowvram_patch_counter = 0
keys = list(self.backup.keys()) keys = list(self.backup.keys())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment