Commit ef4f6037 authored by comfyanonymous's avatar comfyanonymous
Browse files

Fix model patches not working in custom sampling scheduler nodes.

parent a7874d1a
......@@ -174,13 +174,14 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_model(self, device_to=None):
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = getattr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
if patch_weights:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
......
......@@ -26,7 +26,9 @@ class BasicScheduler:
if denoise < 1.0:
total_steps = int(steps/denoise)
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
inner_model = model.patch_model(patch_weights=False)
sigmas = comfy.samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
model.unpatch_model()
sigmas = sigmas[-(steps + 1):]
return (sigmas, )
......@@ -104,7 +106,9 @@ class SDTurboScheduler:
def get_sigmas(self, model, steps, denoise):
start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
sigmas = model.model.model_sampling.sigma(timesteps)
inner_model = model.patch_model(patch_weights=False)
sigmas = inner_model.model_sampling.sigma(timesteps)
model.unpatch_model()
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment