Commit 420beeeb authored by comfyanonymous

Clean up and refactor sampler code.

This should make it much easier to write custom nodes with k-diffusion-style
samplers.
parent 94cc718e
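
To make the commit message concrete: after this refactor, any function with the k-diffusion sampler signature (model, noise, sigmas, extra_args, callback, disable) can be wrapped in the new KSAMPLER class and used wherever a SAMPLER object is expected. The following is a minimal sketch, not part of this commit; my_euler_function and its naive Euler loop are purely illustrative.

# Hypothetical example: a hand-rolled Euler loop wrapped with the refactored KSAMPLER.
import torch
import comfy.samplers

def my_euler_function(model, noise, sigmas, extra_args, callback, disable):
    # "model" is the wrapped denoiser; calling it returns the denoised prediction.
    # "disable" would normally silence a progress bar; this sketch has none.
    x = noise
    s_in = x.new_ones([x.shape[0]])
    for i in range(len(sigmas) - 1):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = (x - denoised) / sigmas[i]           # current noise direction
        x = x + d * (sigmas[i + 1] - sigmas[i])  # plain Euler step toward the next sigma
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "denoised": denoised})
    return x

# The resulting object can be used anywhere a SAMPLER is expected,
# e.g. returned from a custom node's get_sampler().
my_sampler = comfy.samplers.KSAMPLER(my_euler_function)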
comfy/samplers.py

@@ -522,13 +522,17 @@ KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
 
-def ksampler(sampler_name, extra_options={}, inpaint_options={}):
-    class KSAMPLER(Sampler):
+class KSAMPLER(Sampler):
+    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
+        self.sampler_function = sampler_function
+        self.extra_options = extra_options
+        self.inpaint_options = inpaint_options
+
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
         extra_args["denoise_mask"] = denoise_mask
         model_k = KSamplerX0Inpaint(model_wrap)
         model_k.latent_image = latent_image
-        if inpaint_options.get("random", False): #TODO: Should this be the default?
+        if self.inpaint_options.get("random", False): #TODO: Should this be the default?
             generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
             model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
         else:
@@ -544,20 +548,33 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
         if callback is not None:
             k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
 
-        sigma_min = sigmas[-1]
-        if sigma_min == 0:
-            sigma_min = sigmas[-2]
-
         if latent_image is not None:
             noise += latent_image
+
+        samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
+        return samples
+
+
+def ksampler(sampler_name, extra_options={}, inpaint_options={}):
     if sampler_name == "dpm_fast":
-        samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+        def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
+            sigma_min = sigmas[-1]
+            if sigma_min == 0:
+                sigma_min = sigmas[-2]
+            total_steps = len(sigmas) - 1
+            return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
+        sampler_function = dpm_fast_function
     elif sampler_name == "dpm_adaptive":
-        samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
+            sigma_min = sigmas[-1]
+            if sigma_min == 0:
+                sigma_min = sigmas[-2]
+            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
+        sampler_function = dpm_adaptive_function
     else:
-        samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
-        return samples
-    return KSAMPLER
+        sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
+
+    return KSAMPLER(sampler_function, extra_options, inpaint_options)
 
 def wrap_model(model):
     model_denoise = CFGNoisePredictor(model)
@@ -618,11 +635,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
         print("error invalid scheduler", self.scheduler)
     return sigmas
 
-def sampler_class(name):
+def sampler_object(name):
     if name == "uni_pc":
-        sampler = UNIPC
+        sampler = UNIPC()
     elif name == "uni_pc_bh2":
-        sampler = UNIPCBH2
+        sampler = UNIPCBH2()
     elif name == "ddim":
         sampler = ksampler("euler", inpaint_options={"random": True})
     else:
@@ -687,6 +704,6 @@ class KSampler:
         else:
             return torch.zeros_like(noise)
 
-        sampler = sampler_class(self.sampler)
-        return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
+        sampler = sampler_object(self.sampler)
+        return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
comfy_extras/nodes_custom_sampler.py

@@ -149,7 +149,7 @@ class KSamplerSelect:
     FUNCTION = "get_sampler"
 
     def get_sampler(self, sampler_name):
-        sampler = comfy.samplers.sampler_class(sampler_name)()
+        sampler = comfy.samplers.sampler_object(sampler_name)
         return (sampler, )
 
 class SamplerDPMPP_2M_SDE:
@@ -172,7 +172,7 @@ class SamplerDPMPP_2M_SDE:
             sampler_name = "dpmpp_2m_sde"
         else:
             sampler_name = "dpmpp_2m_sde_gpu"
-        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
+        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
         return (sampler, )
 
@@ -196,7 +196,7 @@ class SamplerDPMPP_SDE:
             sampler_name = "dpmpp_sde"
         else:
             sampler_name = "dpmpp_sde_gpu"
-        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})()
+        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
         return (sampler, )
 
 class SamplerCustom:
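
On the node side, ksampler() and sampler_object() now return ready-to-use sampler objects, so nodes return them directly instead of calling the result with a trailing (). Below is a minimal sketch of a custom node following the same pattern as the updated KSamplerSelect / SamplerDPMPP_SDE nodes; the class name and parameters are illustrative, not from this commit.

# Hypothetical custom node sketch; names are illustrative only.
import comfy.samplers

class SamplerEulerAncestralExample:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                             "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling/samplers"
    FUNCTION = "get_sampler"

    def get_sampler(self, eta, s_noise):
        # extra_options are forwarded as keyword arguments to the underlying
        # k-diffusion sampler function (sample_euler_ancestral here).
        sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
        return (sampler, )

Registering such a class in the file's NODE_CLASS_MAPPINGS, as the existing nodes do, would expose it in the UI and let its SAMPLER output feed SamplerCustom.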