Commit 884ea653 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add a way for nodes to set a custom CFG function.

parent 0ab5c619
...@@ -211,6 +211,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con ...@@ -211,6 +211,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area() max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
if "sampler_cfg_function" in model_options:
return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
else:
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
......
...@@ -250,6 +250,9 @@ class ModelPatcher: ...@@ -250,6 +250,9 @@ class ModelPatcher:
def set_model_tomesd(self, ratio): def set_model_tomesd(self, ratio):
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
def set_model_sampler_cfg_function(self, sampler_cfg_function):
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def model_dtype(self): def model_dtype(self):
return self.model.diffusion_model.dtype return self.model.diffusion_model.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment