Commit 420beeeb authored by comfyanonymous

Clean up and refactor sampler code.

This should make it much easier to write custom nodes that use k-diffusion-style samplers.
parent 94cc718e
@@ -522,42 +522,59 @@ KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]

-def ksampler(sampler_name, extra_options={}, inpaint_options={}):
-    class KSAMPLER(Sampler):
-        def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
-            extra_args["denoise_mask"] = denoise_mask
-            model_k = KSamplerX0Inpaint(model_wrap)
-            model_k.latent_image = latent_image
-            if inpaint_options.get("random", False): #TODO: Should this be the default?
-                generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
-                model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
-            else:
-                model_k.noise = noise
+class KSAMPLER(Sampler):
+    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
+        self.sampler_function = sampler_function
+        self.extra_options = extra_options
+        self.inpaint_options = inpaint_options

-            if self.max_denoise(model_wrap, sigmas):
-                noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
-            else:
-                noise = noise * sigmas[0]
+    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
+        extra_args["denoise_mask"] = denoise_mask
+        model_k = KSamplerX0Inpaint(model_wrap)
+        model_k.latent_image = latent_image
+        if self.inpaint_options.get("random", False): #TODO: Should this be the default?
+            generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
+            model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
+        else:
+            model_k.noise = noise

-            k_callback = None
-            total_steps = len(sigmas) - 1
-            if callback is not None:
-                k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
+        if self.max_denoise(model_wrap, sigmas):
+            noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
+        else:
+            noise = noise * sigmas[0]

-            sigma_min = sigmas[-1]
-            if sigma_min == 0:
-                sigma_min = sigmas[-2]
+        k_callback = None
+        total_steps = len(sigmas) - 1
+        if callback is not None:
+            k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)

-            if latent_image is not None:
-                noise += latent_image
-            if sampler_name == "dpm_fast":
-                samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-            elif sampler_name == "dpm_adaptive":
-                samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-            else:
-                samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
-            return samples
-    return KSAMPLER
+        if latent_image is not None:
+            noise += latent_image
+
+        samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
+        return samples
+
+def ksampler(sampler_name, extra_options={}, inpaint_options={}):
+    if sampler_name == "dpm_fast":
+        def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
+            sigma_min = sigmas[-1]
+            if sigma_min == 0:
+                sigma_min = sigmas[-2]
+            total_steps = len(sigmas) - 1
+            return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
+        sampler_function = dpm_fast_function
+    elif sampler_name == "dpm_adaptive":
+        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
+            sigma_min = sigmas[-1]
+            if sigma_min == 0:
+                sigma_min = sigmas[-2]
+            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
+        sampler_function = dpm_adaptive_function
+    else:
+        sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
+
+    return KSAMPLER(sampler_function, extra_options, inpaint_options)

 def wrap_model(model):
     model_denoise = CFGNoisePredictor(model)
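The payoff of this hunk for custom-node authors is that any callable with the k-diffusion sampler signature can now be handed straight to KSAMPLER, instead of going through the old closure-based factory. A minimal sketch, assuming a ComfyUI checkout on the import path; my_euler_function is a hypothetical example, not part of this commit:

import torch
from comfy.samplers import KSAMPLER

@torch.no_grad()
def my_euler_function(model, x, sigmas, extra_args=None, callback=None, disable=None):
    # Same calling convention KSAMPLER.sample() uses above:
    # model(x, sigma, **extra_args) returns the denoised prediction.
    extra_args = {} if extra_args is None else extra_args
    for i in range(len(sigmas) - 1):
        sigma = sigmas[i] * x.new_ones([x.shape[0]])
        denoised = model(x, sigma, **extra_args)
        if callback is not None:
            # KSAMPLER's k_callback reads the "i", "denoised" and "x" keys.
            callback({"i": i, "denoised": denoised, "x": x})
        d = (x - denoised) / sigmas[i]            # derivative estimate
        x = x + d * (sigmas[i + 1] - sigmas[i])   # plain Euler step to the next sigma
    return x

# Wrap it the same way ksampler() wraps the built-in samplers.
sampler = KSAMPLER(my_euler_function)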
@@ -618,11 +635,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
         print("error invalid scheduler", self.scheduler)
     return sigmas

-def sampler_class(name):
+def sampler_object(name):
     if name == "uni_pc":
-        sampler = UNIPC
+        sampler = UNIPC()
     elif name == "uni_pc_bh2":
-        sampler = UNIPCBH2
+        sampler = UNIPCBH2()
     elif name == "ddim":
         sampler = ksampler("euler", inpaint_options={"random": True})
     else:
@@ -687,6 +704,6 @@ class KSampler:
         else:
             return torch.zeros_like(noise)

-        sampler = sampler_class(self.sampler)
-        return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
+        sampler = sampler_object(self.sampler)
+        return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
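For callers, the visible API change is that sampler_object() (renamed from sampler_class) returns a ready-to-use Sampler instance rather than a class, so the trailing () disappears at every call site; this is exactly the pattern the custom-sampler node hunks below apply:

# Before this commit: the helper returned a class that still had to be instantiated.
sampler = comfy.samplers.sampler_class(sampler_name)()

# After: an instance comes back directly.
sampler = comfy.samplers.sampler_object(sampler_name)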
@@ -149,7 +149,7 @@ class KSamplerSelect:
     FUNCTION = "get_sampler"

     def get_sampler(self, sampler_name):
-        sampler = comfy.samplers.sampler_class(sampler_name)()
+        sampler = comfy.samplers.sampler_object(sampler_name)
         return (sampler, )

 class SamplerDPMPP_2M_SDE:
@@ -172,7 +172,7 @@ class SamplerDPMPP_2M_SDE:
             sampler_name = "dpmpp_2m_sde"
         else:
             sampler_name = "dpmpp_2m_sde_gpu"
-        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
+        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
         return (sampler, )

@@ -196,7 +196,7 @@ class SamplerDPMPP_SDE:
             sampler_name = "dpmpp_sde"
         else:
             sampler_name = "dpmpp_sde_gpu"
-        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})()
+        sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
         return (sampler, )

 class SamplerCustom:
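Taken together, the node hunks show the shape a third-party SAMPLER node follows after this commit. A minimal sketch mirroring SamplerDPMPP_SDE above; MySamplerNode and its single eta input are illustrative and not part of the commit:

import comfy.samplers

class MySamplerNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0})}}

    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling"
    FUNCTION = "get_sampler"

    def get_sampler(self, eta):
        # ksampler() now returns a KSAMPLER instance, so no trailing ().
        # "euler_ancestral" accepts eta as a k-diffusion extra option.
        sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta})
        return (sampler, )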