"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "bb5b4db98d045fb100a2b32996e688e63d86c68b"
Commit 0542088e authored by comfyanonymous's avatar comfyanonymous
Browse files

Refactor sampler code for more advanced sampler nodes part 2.

parent 57753c96
import torch import torch
import comfy.model_management import comfy.model_management
import comfy.samplers import comfy.samplers
import comfy.conds
import comfy.utils import comfy.utils
import math
import numpy as np import numpy as np
import logging
def prepare_noise(latent_image, seed, noise_inds=None): def prepare_noise(latent_image, seed, noise_inds=None):
""" """
...@@ -25,106 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None): ...@@ -25,106 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
noises = torch.cat(noises, axis=0) noises = torch.cat(noises, axis=0)
return noises return noises
def prepare_mask(noise_mask, shape, device): def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
"""ensures noise mask is of proper dimensions""" logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") return model, positive, negative, noise_mask, []
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c:
models += [c[model_type]]
return models
def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
out.append(temp)
return out
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets = []
gligen = []
for i in range(len(conds)):
cnets += get_models_from_cond(conds[i], "control")
gligen += get_models_from_cond(conds[i], "gligen")
control_nets = set(cnets)
inference_memory = 0
control_models = []
for m in control_nets:
control_models += m.get_models()
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
models = control_models + gligen
return models, inference_memory
def cleanup_additional_models(models): def cleanup_additional_models(models):
"""cleanup additional models that were loaded""" logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
for m in models:
if hasattr(m, 'cleanup'):
m.cleanup()
def prepare_sampling(model, noise_shape, conds, noise_mask):
device = model.load_device
for i in range(len(conds)):
conds[i] = convert_cond(conds[i])
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise_shape, device)
real_model = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
return real_model, conds, noise_mask, models
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
real_model, conds_copy, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask) sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
positive_copy, negative_copy = conds_copy
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device()) samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
real_model, conds, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask) samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
noise = noise.to(model.load_device)
latent_image = latent_image.to(model.load_device)
sigmas = sigmas.to(model.load_device)
samples = comfy.samplers.sample(real_model, noise, conds[0], conds[1], cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device()) samples = samples.to(comfy.model_management.intermediate_device())
cleanup_additional_models(models)
control_cleanup = []
for i in range(len(conds)):
control_cleanup += get_models_from_cond(conds[i], "control")
cleanup_additional_models(set(control_cleanup))
return samples return samples
import torch
import comfy.model_management
import comfy.conds
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c:
models += [c[model_type]]
return models
def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
out.append(temp)
return out
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets = []
gligen = []
for k in conds:
cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen")
control_nets = set(cnets)
inference_memory = 0
control_models = []
for m in control_nets:
control_models += m.get_models()
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
models = control_models + gligen
return models, inference_memory
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
if hasattr(m, 'cleanup'):
m.cleanup()
def prepare_sampling(model, noise_shape, conds):
device = model.load_device
real_model = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
return real_model, conds, models
def cleanup_models(conds, models):
cleanup_additional_models(models)
control_cleanup = []
for k in conds:
control_cleanup += get_models_from_cond(conds[k], "control")
cleanup_additional_models(set(control_cleanup))
...@@ -5,6 +5,7 @@ import collections ...@@ -5,6 +5,7 @@ import collections
from comfy import model_management from comfy import model_management
import math import math
import logging import logging
import comfy.sampler_helpers
def get_area_and_mult(conds, x_in, timestep_in): def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
...@@ -230,58 +231,45 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): ...@@ -230,58 +231,45 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
#The main sampling function shared by all the samplers def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}):
#Returns denoised if "sampler_cfg_function" in model_options:
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
uncond_ = None cfg_result = x - model_options["sampler_cfg_function"](args)
else: else:
uncond_ = uncond cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
conds = [cond, uncond_] for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
out = calc_cond_batch(model, conds, x, timestep, model_options) return cfg_result
cond_pred = out[0]
uncond_pred = out[1]
if "sampler_cfg_function" in model_options: #The main sampling function shared by all the samplers
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, #Returns denoised
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
cfg_result = x - model_options["sampler_cfg_function"](args) if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
else: uncond_ = None
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale else:
uncond_ = uncond
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
return cfg_result conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options)
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model, cond_scale=1.0):
super().__init__()
self.inner_model = model
self.cond_scale = cond_scale
def apply_model(self, x, timestep, conds, model_options={}, seed=None):
out = sampling_function(self.inner_model, x, timestep, conds.get("negative", None), conds.get("positive", None), self.cond_scale, model_options=model_options, seed=seed)
return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
class KSamplerX0Inpaint(torch.nn.Module): class KSamplerX0Inpaint:
def __init__(self, model, sigmas): def __init__(self, model, sigmas):
super().__init__()
self.inner_model = model self.inner_model = model
self.sigmas = sigmas self.sigmas = sigmas
def forward(self, x, sigma, conds, denoise_mask, model_options={}, seed=None): def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None: if denoise_mask is not None:
if "denoise_mask_function" in model_options: if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask latent_mask = 1. - denoise_mask
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
out = self.inner_model(x, sigma, conds=conds, model_options=model_options, seed=seed) out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
if denoise_mask is not None: if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask out = out * denoise_mask + self.latent_image * latent_mask
return out return out
...@@ -601,22 +589,66 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N ...@@ -601,22 +589,66 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
return conds return conds
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
def set_conds(self, conds):
for k in conds:
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def set_cfg(self, cfg):
self.cfg = cfg
def __call__(self, *args, **kwargs):
return self.predict_noise(*args, **kwargs)
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
def sample_advanced(model, noise, conds, guider_class, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image) extra_args = {"model_options": self.model_options, "seed":seed}
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
conds = process_conds(model, noise, conds, device, latent_image, denoise_mask, seed) noise = noise.to(device)
model_wrap = guider_class(model) latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
extra_args = {"conds": conds, "model_options": model_options, "seed":seed} output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
return model.process_latent_out(samples.to(torch.float32)) del self.inner_model
del self.conds
del self.loaded_models
return output
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
return sample_advanced(model, noise, {"positive": positive, "negative": negative}, lambda a: CFGNoisePredictor(a, cfg), device, sampler, sigmas, model_options, latent_image, denoise_mask, callback, disable_pbar, seed) cfg_guider = CFGGuider(model)
cfg_guider.set_conds({"positive": positive, "negative": negative})
cfg_guider.set_cfg(cfg)
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
...@@ -676,7 +708,7 @@ class KSampler: ...@@ -676,7 +708,7 @@ class KSampler:
steps += 1 steps += 1
discard_penultimate_sigma = True discard_penultimate_sigma = True
sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps) sigmas = calculate_sigmas_scheduler(self.model.model, self.scheduler, steps)
if discard_penultimate_sigma: if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment