Commit 873911fa authored by fengyf1's avatar fengyf1
Browse files

Initial commit

parents
#Taken from: https://github.com/zju-pi/diff-sampler/blob/main/gits-main/solver_utils.py
#under Apache 2 license
import torch
import numpy as np
# A pytorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
#############################
### Utils for DEIS solver ###
#############################
#----------------------------------------------------------------------------
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
t_steps = vp_sigma_inv(vp_beta_d.clone().detach().cpu(), vp_beta_min.clone().detach().cpu())(edm_steps.clone().detach().cpu())
return t_steps, vp_beta_min, vp_beta_d + vp_beta_min
#----------------------------------------------------------------------------
def cal_poly(prev_t, j, taus):
poly = 1
for k in range(prev_t.shape[0]):
if k == j:
continue
poly *= (taus - prev_t[k]) / (prev_t[j] - prev_t[k])
return poly
#----------------------------------------------------------------------------
# Transfer from t to alpha_t.
def t2alpha_fn(beta_0, beta_1, t):
return torch.exp(-0.5 * t ** 2 * (beta_1 - beta_0) - t * beta_0)
#----------------------------------------------------------------------------
def cal_intergrand(beta_0, beta_1, taus):
with torch.inference_mode(mode=False):
taus = taus.clone()
beta_0 = beta_0.clone()
beta_1 = beta_1.clone()
with torch.enable_grad():
taus.requires_grad_(True)
alpha = t2alpha_fn(beta_0, beta_1, taus)
log_alpha = alpha.log()
log_alpha.sum().backward()
d_log_alpha_dtau = taus.grad
integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
return integrand
#----------------------------------------------------------------------------
def get_deis_coeff_list(t_steps, max_order, N=10000, deis_mode='tab'):
"""
Get the coefficient list for DEIS sampling.
Args:
t_steps: A pytorch tensor. The time steps for sampling.
max_order: A `int`. Maximum order of the solver. 1 <= max_order <= 4
N: A `int`. Use how many points to perform the numerical integration when deis_mode=='tab'.
deis_mode: A `str`. Select between 'tab' and 'rhoab'. Type of DEIS.
Returns:
A pytorch tensor. A batch of generated samples or sampling trajectories if return_inters=True.
"""
if deis_mode == 'tab':
t_steps, beta_0, beta_1 = edm2t(t_steps)
C = []
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
order = min(i+1, max_order)
if order == 1:
C.append([])
else:
taus = torch.linspace(t_cur, t_next, N) # split the interval for integral appximation
dtau = (t_next - t_cur) / N
prev_t = t_steps[[i - k for k in range(order)]]
coeff_temp = []
integrand = cal_intergrand(beta_0, beta_1, taus)
for j in range(order):
poly = cal_poly(prev_t, j, taus)
coeff_temp.append(torch.sum(integrand * poly) * dtau)
C.append(coeff_temp)
elif deis_mode == 'rhoab':
# Analytical solution, second order
def get_def_intergral_2(a, b, start, end, c):
coeff = (end**3 - start**3) / 3 - (end**2 - start**2) * (a + b) / 2 + (end - start) * a * b
return coeff / ((c - a) * (c - b))
# Analytical solution, third order
def get_def_intergral_3(a, b, c, start, end, d):
coeff = (end**4 - start**4) / 4 - (end**3 - start**3) * (a + b + c) / 3 \
+ (end**2 - start**2) * (a*b + a*c + b*c) / 2 - (end - start) * a * b * c
return coeff / ((d - a) * (d - b) * (d - c))
C = []
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
order = min(i, max_order)
if order == 0:
C.append([])
else:
prev_t = t_steps[[i - k for k in range(order+1)]]
if order == 1:
coeff_cur = ((t_next - prev_t[1])**2 - (t_cur - prev_t[1])**2) / (2 * (t_cur - prev_t[1]))
coeff_prev1 = (t_next - t_cur)**2 / (2 * (prev_t[1] - t_cur))
coeff_temp = [coeff_cur, coeff_prev1]
elif order == 2:
coeff_cur = get_def_intergral_2(prev_t[1], prev_t[2], t_cur, t_next, t_cur)
coeff_prev1 = get_def_intergral_2(t_cur, prev_t[2], t_cur, t_next, prev_t[1])
coeff_prev2 = get_def_intergral_2(t_cur, prev_t[1], t_cur, t_next, prev_t[2])
coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2]
elif order == 3:
coeff_cur = get_def_intergral_3(prev_t[1], prev_t[2], prev_t[3], t_cur, t_next, t_cur)
coeff_prev1 = get_def_intergral_3(t_cur, prev_t[2], prev_t[3], t_cur, t_next, prev_t[1])
coeff_prev2 = get_def_intergral_3(t_cur, prev_t[1], prev_t[3], t_cur, t_next, prev_t[2])
coeff_prev3 = get_def_intergral_3(t_cur, prev_t[1], prev_t[2], t_cur, t_next, prev_t[3])
coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3]
C.append(coeff_temp)
return C
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
# Codebase ref: https://github.com/scxue/SA-Solver
import math
from typing import Union, Callable
import torch
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
"""Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.
Integral of exp((1 + tau^2) * x) * x^p dx
= product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
with base case p=0 where integral equals product_terms[0].
where
product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).
Construct a recursive coefficient matrix following the above recursive relation to compute all integral terms up to p = (solver_order - 1).
Return coefficients used by the SA-Solver in data prediction mode.
Args:
s: Start time s.
t: End time t.
solver_order: Current order of the solver.
tau_t: Stochastic strength parameter in the SDE.
Returns:
Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order−1, shape (solver_order,).
"""
tau_mul = 1 + tau_t ** 2
h = t - s
p = torch.arange(solver_order, dtype=s.dtype, device=s.device)
# product_terms after factoring out exp((1 + tau^2) * t)
# Includes (1 + tau^2) factor from outside the integral
product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())
# Lower triangular recursive coefficient matrix
# Accumulates recursive coefficients based on p / (1 + tau^2)
recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
log_factorial = (p + 1).lgamma()
recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
if tau_t > 0:
recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()
return recursive_coeff_mat @ product_terms_factored
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details)."""
tau_mul = 1 + tau_t ** 2
h = lambda_t - lambda_s
alpha_t = sigma_next * lambda_t.exp()
if is_corrector_step:
# Simplified 1-step (order-2) corrector
b_1 = alpha_t * (0.5 * tau_mul * h)
b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
else:
# Simplified 2-step predictor
b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
return torch.stack([b_2, b_1])
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
"""Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).
The solver order corresponds to the number of input lambdas (half-logSNR points).
Args:
sigma_next: Sigma at end time t.
curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
lambda_s: Lambda at start time s.
lambda_t: Lambda at end time t.
tau_t: Stochastic strength parameter in the SDE.
simple_order_2: Whether to enable the simple order-2 scheme.
is_corrector_step: Flag for corrector step in simple order-2 mode.
Returns:
b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
"""
num_timesteps = curr_lambdas.shape[0]
if simple_order_2 and num_timesteps == 2:
return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)
# Compute coefficients by solving a linear system from Lagrange basis interpolation
exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)
# (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
# = sigma_t * exp(lambda_t) = alpha_t
# exp((1 + tau^2) * lambda_t) is extracted from the integral
alpha_t = sigma_next * lambda_t.exp()
return alpha_t * lagrange_integrals
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
"""Return a function that controls the stochasticity of SA-Solver.
When eta = 0, SA-Solver runs as ODE. The official approach uses
time t to determine the SDE interval, while here we use sigma instead.
See:
https://github.com/scxue/SA-Solver/blob/main/README.md
"""
def tau_func(sigma: Union[torch.Tensor, float]) -> float:
if eta <= 0:
return 0.0 # ODE
if isinstance(sigma, torch.Tensor):
sigma = sigma.item()
return eta if start_sigma >= sigma >= end_sigma else 0.0
return tau_func
import math
from functools import partial
from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange, tqdm
from . import utils
from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device)
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
"""Constructs an exponential noise schedule."""
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
return append_zero(sigmas)
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
"""Constructs an polynomial in log sigma noise schedule."""
ramp = torch.linspace(1, 0, n, device=device) ** rho
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
return append_zero(sigmas)
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t))
return append_zero(sigmas)
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
epsilon = 1e-5 # avoid log(0)
x = torch.linspace(0, 1, n, device=device)
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
sigmas = clamp(torch.exp(lmb))
return sigmas
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
"""Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step."""
if not eta:
return sigma_to, 0.
sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
return sigma_down, sigma_up
def default_noise_sampler(x, seed=None):
if seed is not None:
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
else:
generator = None
return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)
class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
self.cpu_tree = True
if "cpu" in kwargs:
self.cpu_tree = kwargs.pop("cpu")
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get('w0', torch.zeros_like(x))
if seed is None:
seed = torch.randint(0, 2 ** 63 - 1, []).item()
self.batched = True
try:
assert len(seed) == x.shape[0]
w0 = w0[0]
except TypeError:
seed = [seed]
self.batched = False
if self.cpu_tree:
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
else:
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
@staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
if self.cpu_tree:
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
else:
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
"""A noise sampler backed by a torchsde.BrownianTree.
Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
random samples.
sigma_min (float): The low end of the valid interval.
sigma_max (float): The high end of the valid interval.
seed (int or List[int]): The random seed. If a list of seeds is
supplied instead of a single integer, then the noise sampler will
use one BrownianTree per batch item, each with its own seed.
transform (callable): A function that maps sigma to the sampler's
internal timestep.
"""
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
self.transform = transform
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
def __call__(self, sigma, sigma_next):
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
def sigma_to_half_log_snr(sigma, model_sampling):
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# log((1 - t) / t) = log((1 - sigma) / sigma)
return sigma.logit().neg()
return sigma.log().neg()
def half_log_snr_to_sigma(half_log_snr, model_sampling):
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
if isinstance(model_sampling, comfy.model_sampling.CONST):
# 1 / (1 + exp(half_log_snr))
return half_log_snr.neg().sigmoid()
return half_log_snr.neg().exp()
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
"""Adjust the first sigma to avoid invalid logSNR."""
if len(sigmas) <= 1:
return sigmas
if isinstance(model_sampling, comfy.model_sampling.CONST):
if sigmas[0] >= 1:
sigmas = sigmas.clone()
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
return sigmas
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
sigma_hat = sigmas[i] * (gamma + 1)
else:
gamma = 0
sigma_hat = sigmas[i]
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
return x
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
x = denoised
else:
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
sigma_down = sigmas[i + 1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i + 1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if eta > 0:
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
sigma_hat = sigmas[i] * (gamma + 1)
else:
gamma = 0
sigma_hat = sigmas[i]
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
return x
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
sigma_hat = sigmas[i] * (gamma + 1)
else:
gamma = 0
sigma_hat = sigmas[i]
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
dt = sigmas[i + 1] - sigma_hat
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
dt_1 = sigma_mid - sigma_hat
dt_2 = sigmas[i + 1] - sigma_hat
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
return x
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:
raise ValueError(f'Order {order} too high for step {i}')
def fn(tau):
prod = 1.
for k in range(order):
if j == k:
continue
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
return prod
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigmas_cpu = sigmas.detach().cpu().numpy()
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
if len(ds) > order:
ds.pop(0)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
class PIDStepSizeController:
"""A PID controller for ODE adaptive step size control."""
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
self.h = h
self.b1 = (pcoeff + icoeff + dcoeff) / order
self.b2 = -(pcoeff + 2 * dcoeff) / order
self.b3 = dcoeff / order
self.accept_safety = accept_safety
self.eps = eps
self.errs = []
def limiter(self, x):
return 1 + math.atan(x - 1)
def propose_step(self, error):
inv_error = 1 / (float(error) + self.eps)
if not self.errs:
self.errs = [inv_error, inv_error, inv_error]
self.errs[0] = inv_error
factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
factor = self.limiter(factor)
accept = factor >= self.accept_safety
if accept:
self.errs[2] = self.errs[1]
self.errs[1] = self.errs[0]
self.h *= factor
return accept
class DPMSolver(nn.Module):
"""DPM-Solver. See https://arxiv.org/abs/2206.00927."""
def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
super().__init__()
self.model = model
self.extra_args = {} if extra_args is None else extra_args
self.eps_callback = eps_callback
self.info_callback = info_callback
def t(self, sigma):
return -sigma.log()
def sigma(self, t):
return t.neg().exp()
def eps(self, eps_cache, key, x, t, *args, **kwargs):
if key in eps_cache:
return eps_cache[key], eps_cache
sigma = self.sigma(t) * x.new_ones([x.shape[0]])
eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
if self.eps_callback is not None:
self.eps_callback()
return eps, {key: eps, **eps_cache}
def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
eps_cache = {} if eps_cache is None else eps_cache
h = t_next - t
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
x_1 = x - self.sigma(t_next) * h.expm1() * eps
return x_1, eps_cache
def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
eps_cache = {} if eps_cache is None else eps_cache
h = t_next - t
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
s1 = t + r1 * h
u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
return x_2, eps_cache
def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
eps_cache = {} if eps_cache is None else eps_cache
h = t_next - t
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
s1 = t + r1 * h
s2 = t + r2 * h
u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
return x_3, eps_cache
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
if not t_end > t_start and eta:
raise ValueError('eta must be 0 for reverse sampling')
m = math.floor(nfe / 3) + 1
ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
if nfe % 3 == 0:
orders = [3] * (m - 2) + [2, 1]
else:
orders = [3] * (m - 1) + [nfe % 3]
for i in range(len(orders)):
eps_cache = {}
t, t_next = ts[i], ts[i + 1]
if eta:
sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
t_next_ = torch.minimum(t_end, self.t(sd))
su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
else:
t_next_, su = t_next, 0.
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
denoised = x - self.sigma(t) * eps
if self.info_callback is not None:
self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
if orders[i] == 1:
x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
elif orders[i] == 2:
x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
else:
x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
return x
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
if order not in {2, 3}:
raise ValueError('order should be 2 or 3')
forward = t_end > t_start
if not forward and eta:
raise ValueError('eta must be 0 for reverse sampling')
h_init = abs(h_init) * (1 if forward else -1)
atol = torch.tensor(atol)
rtol = torch.tensor(rtol)
s = t_start
x_prev = x
accept = True
pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
eps_cache = {}
t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
if eta:
sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
t_ = torch.minimum(t_end, self.t(sd))
su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
else:
t_, su = t, 0.
eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
denoised = x - self.sigma(s) * eps
if order == 2:
x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
else:
x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
accept = pid.propose_step(error)
if accept:
x_prev = x_low
x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
s = t
info['n_accept'] += 1
else:
info['n_reject'] += 1
info['nfe'] += order
info['steps'] += 1
if self.info_callback is not None:
self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
return x, info
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
"""DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
with tqdm(total=n, disable=disable) as pbar:
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
if callback is not None:
dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
"""DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
with tqdm(disable=disable) as pbar:
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
if callback is not None:
dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
if return_info:
return x, info
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
# logged_x = x.unsqueeze(0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
if sigmas[i] == 1.0:
sigma_s = 0.9999
else:
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
r = 1 / 2
h = t_down - t_i
s = t_i + r * h
sigma_s = sigma_fn(s)
# sigma_s = sigmas[i+1]
sigma_s_i_ratio = sigma_s / sigmas[i]
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
D_i = model(u, sigma_s * s_in, **extra_args)
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
# Noise addition
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# DPM-Solver++
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
alpha_s = sigmas[i] * lambda_s.exp()
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
lambda_s_1_ = sd.log().neg()
h_ = lambda_s_1_ - lambda_s
x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
if eta > 0 and s_noise > 0:
x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
lambda_t_ = sd.log().neg()
h_ = lambda_t_ - lambda_s
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
return x
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
if len(sigmas) <= 1:
return x
if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
old_denoised = None
h, h_last = None, None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# DPM-Solver++(2M) SDE
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x
@torch.no_grad()
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
alpha_t = sigmas[i + 1] * lambda_t.exp()
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
if h_2 is not None:
# DPM-Solver++(3M) SDE
r0 = h_1 / h
r1 = h_2 / h
d1_0 = (denoised - denoised_1) / r0
d1_1 = (denoised_1 - denoised_2) / r1
d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
d2 = (d1_0 - d1_1) / (r0 + r1)
phi_2 = h_eta.neg().expm1() / h_eta + 1
phi_3 = phi_2 / h_eta - 0.5
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
elif h_1 is not None:
# DPM-Solver++(2M) SDE
r = h_1 / h
d = (denoised - denoised_1) / r
phi_2 = h_eta.neg().expm1() / h_eta + 1
x = x + (alpha_t * phi_2) * d
if eta > 0 and s_noise > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
denoised_1, denoised_2 = denoised, denoised_1
h_1, h_2 = h, h_1
return x
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
alpha_cumprod = 1 / ((sigma * sigma) + 1)
alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
alpha = (alpha_cumprod / alpha_cumprod_prev)
mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt())
if sigma_prev > 0:
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
return mu
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler)
if sigmas[i + 1] != 0:
x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
return x
@torch.no_grad()
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
x = denoised
if sigmas[i + 1] > 0:
x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
s_end = sigmas[-1]
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == s_end:
# Euler method
x = x + d * dt
elif sigmas[i + 2] == s_end:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
w = 2 * sigmas[0]
w2 = sigmas[i+1]/w
w1 = 1 - w2
d_prime = d * w1 + d_2 * w2
x = x + d_prime * dt
else:
# Heun++
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
dt_2 = sigmas[i + 2] - sigmas[i + 1]
x_3 = x_2 + d_2 * dt_2
denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
w = 3 * sigmas[0]
w2 = sigmas[i + 1] / w
w3 = sigmas[i + 2] / w
w1 = 1 - w2 - w3
d_prime = w1 * d + w2 * d_2 + w3 * d_3
x = x + d_prime * dt
return x
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
x_next = x
buffer_model = []
for i in trange(len(sigmas) - 1, disable=disable):
t_cur = sigmas[i]
t_next = sigmas[i + 1]
x_cur = x_next
denoised = model(x_cur, t_cur * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
elif order == 3: # Use two history points.
x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
elif order == 4: # Use three history points.
x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
if len(buffer_model) == max_order - 1:
for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1]
buffer_model[-1] = d_cur
else:
buffer_model.append(d_cur)
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
x_next = x
t_steps = sigmas
buffer_model = []
for i in trange(len(sigmas) - 1, disable=disable):
t_cur = sigmas[i]
t_next = sigmas[i + 1]
x_cur = x_next
denoised = model(x_cur, t_cur * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if t_next == 0: # Denoising step
x_next = denoised
elif order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1])
coeff1 = (2 + (h_n / h_n_1)) / 2
coeff2 = -(h_n / h_n_1) / 2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
elif order == 3: # Use two history points.
h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1])
h_n_2 = (t_steps[i-1] - t_steps[i-2])
temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
coeff3 = temp * h_n_1 / h_n_2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
elif order == 4: # Use three history points.
h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1])
h_n_2 = (t_steps[i-1] - t_steps[i-2])
h_n_3 = (t_steps[i-2] - t_steps[i-3])
temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
* (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3])
if len(buffer_model) == max_order - 1:
for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1]
buffer_model[-1] = d_cur.detach()
else:
buffer_model.append(d_cur.detach())
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
x_next = x
t_steps = sigmas
coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode)
buffer_model = []
for i in trange(len(sigmas) - 1, disable=disable):
t_cur = sigmas[i]
t_next = sigmas[i + 1]
x_cur = x_next
denoised = model(x_cur, t_cur * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if t_next <= 0:
order = 1
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
coeff_cur, coeff_prev1 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
elif order == 3: # Use two history points.
coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
elif order == 4: # Use three history points.
coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]
if len(buffer_model) == max_order - 1:
for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1]
buffer_model[-1] = d_cur.detach()
else:
buffer_model.append(d_cur.detach())
return x_next
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps (CFG++)."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise
# DDIM stochastic sampling
sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
sigma_down = alpha_t * sigma_down
# Euler method
x = alpha_t * denoised + sigma_down * d
if eta > 0 and s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Euler method steps (CFG++)."""
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], temp[0])
x = denoised + d * sigma_down
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
# r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
t_fn = lambda sigma: sigma.log().neg()
old_uncond_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_uncond_denoised is None or sigmas[i + 1] == 0:
denoised_mix = -torch.exp(-h) * uncond_denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
x = denoised + denoised_mix + torch.exp(-h) * x
old_uncond_denoised = uncond_denoised
return x
@torch.no_grad()
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
phi1_fn = lambda t: torch.expm1(t) / t
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
old_sigma_down = None
old_denoised = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
if cfg_pp:
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
if sigma_down == 0 or old_denoised is None:
# Euler method
if cfg_pp:
d = to_d(x, sigmas[i], uncond_denoised)
x = denoised + d * sigma_down
else:
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# Second order multistep method in https://arxiv.org/pdf/2308.02157
t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
h = t_next - t
c2 = (t_prev - t_old) / h
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
if cfg_pp:
x = x + (denoised - uncond_denoised)
x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
else:
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if cfg_pp:
old_denoised = uncond_denoised
else:
old_denoised = denoised
old_sigma_down = sigma_down
return x
@torch.no_grad()
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
@torch.no_grad()
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True)
@torch.no_grad()
def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False)
@torch.no_grad()
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_d = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
if cfg_pp:
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if cfg_pp:
d = to_d(x, sigmas[i], uncond_denoised)
else:
d = to_d(x, sigmas[i], denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
dt = sigmas[i + 1] - sigmas[i]
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# Euler method
if cfg_pp:
x = denoised + d * sigmas[i + 1]
else:
x = x + d * dt
if i >= 1:
# Gradient estimation
d_bar = (ge_gamma - 1) * (d - old_d)
x = x + d_bar * dt
old_d = d
return x
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
"""Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
def default_er_sde_noise_scaler(x):
return x * ((x ** 0.3).exp() + 10.0)
noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
num_integration_points = 200.0
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t
old_denoised = None
old_denoised_d = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
stage_used = min(max_stage, i + 1)
if sigmas[i + 1] == 0:
x = denoised
else:
er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
alpha_s = sigmas[i] / er_lambda_s
alpha_t = sigmas[i + 1] / er_lambda_t
r_alpha = alpha_t / alpha_s
r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
# Stage 1 Euler
x = r_alpha * r * x + alpha_t * (1 - r) * denoised
if stage_used >= 2:
dt = er_lambda_t - er_lambda_s
lambda_step_size = -dt / num_integration_points
lambda_pos = er_lambda_t + point_indice * lambda_step_size
scaled_pos = noise_scaler(lambda_pos)
# Stage 2
s = torch.sum(1 / scaled_pos) * lambda_step_size
denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
if stage_used >= 3:
# Stage 3
s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
old_denoised_d = denoised_d
if s_noise > 0:
x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r_1 < r_2 < 1
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
return x
@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
"""Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
if len(sigmas) <= 1:
return x
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)
if tau_func is None:
# Use default interval for stochastic sampling
start_sigma = model_sampling.percent_to_sigma(0.2)
end_sigma = model_sampling.percent_to_sigma(0.8)
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)
max_used_order = max(predictor_order, corrector_order)
x_pred = x # x: current state, x_pred: predicted next state
h = 0.0
tau_t = 0.0
noise = 0.0
pred_list = []
# Lower order near the end to improve stability
lower_order_to_end = sigmas[-1].item() == 0
for i in trange(len(sigmas) - 1, disable=disable):
# Evaluation
denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
pred_list.append(denoised)
pred_list = pred_list[-max_used_order:]
predictor_order_used = min(predictor_order, len(pred_list))
if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
corrector_order_used = 0
else:
corrector_order_used = min(corrector_order, len(pred_list))
if lower_order_to_end:
predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)
# Corrector
if corrector_order_used == 0:
# Update by the predicted state
x = x_pred
else:
curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
sigmas[i],
curr_lambdas,
lambdas[i - 1],
lambdas[i],
tau_t,
simple_order_2,
is_corrector_step=True,
)
pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res
if tau_t > 0 and s_noise > 0:
# The noise from the previous predictor step
x = x + noise
if use_pece:
# Evaluate the corrected state
denoised = model(x, sigmas[i] * s_in, **extra_args)
pred_list[-1] = denoised
# Predictor
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
tau_t = tau_func(sigmas[i + 1])
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
sigmas[i + 1],
curr_lambdas,
lambdas[i],
lambdas[i + 1],
tau_t,
simple_order_2,
is_corrector_step=False,
)
pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
h = lambdas[i + 1] - lambdas[i]
x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res
if tau_t > 0 and s_noise > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
x_pred = x_pred + noise
return x
@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
from contextlib import contextmanager
import hashlib
import math
from pathlib import Path
import shutil
import urllib
import warnings
from PIL import Image
import torch
from torch import nn, optim
from torch.utils import data
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
"""Apply passed in transforms for HuggingFace Datasets."""
images = [transform(image.convert(mode)) for image in examples[image_key]]
return {image_key: images}
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
expanded = x[(...,) + (None,) * dims_to_append]
# MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
# https://github.com/pytorch/pytorch/issues/84364
return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
def n_params(module):
"""Returns the number of trainable parameters in a module."""
return sum(p.numel() for p in module.parameters())
def download_file(path, url, digest=None):
"""Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
if not path.exists():
with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
shutil.copyfileobj(response, f)
if digest is not None:
file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
if digest != file_digest:
raise OSError(f'hash of {path} (url: {url}) failed to validate')
return path
@contextmanager
def train_mode(model, mode=True):
"""A context manager that places a model into training mode and restores
the previous mode on exit."""
modes = [module.training for module in model.modules()]
try:
yield model.train(mode)
finally:
for i, module in enumerate(model.modules()):
module.training = modes[i]
def eval_mode(model):
"""A context manager that places a model into evaluation mode and restores
the previous mode on exit."""
return train_mode(model, False)
@torch.no_grad()
def ema_update(model, averaged_model, decay):
"""Incorporates updated model parameters into an exponential moving averaged
version of a model. It should be called after each optimizer step."""
model_params = dict(model.named_parameters())
averaged_params = dict(averaged_model.named_parameters())
assert model_params.keys() == averaged_params.keys()
for name, param in model_params.items():
averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
model_buffers = dict(model.named_buffers())
averaged_buffers = dict(averaged_model.named_buffers())
assert model_buffers.keys() == averaged_buffers.keys()
for name, buf in model_buffers.items():
averaged_buffers[name].copy_(buf)
class EMAWarmup:
"""Implements an EMA warmup using an inverse decay schedule.
If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 1.
min_value (float): The minimum EMA decay rate. Default: 0.
max_value (float): The maximum EMA decay rate. Default: 1.
start_at (int): The epoch to start averaging at. Default: 0.
last_epoch (int): The index of last epoch. Default: 0.
"""
def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
last_epoch=0):
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.start_at = start_at
self.last_epoch = last_epoch
def state_dict(self):
"""Returns the state of the class as a :class:`dict`."""
return dict(self.__dict__.items())
def load_state_dict(self, state_dict):
"""Loads the class's state.
Args:
state_dict (dict): scaler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def get_value(self):
"""Gets the current EMA decay rate."""
epoch = max(0, self.last_epoch - self.start_at)
value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
def step(self):
"""Updates the step count."""
self.last_epoch += 1
class InverseLR(optim.lr_scheduler._LRScheduler):
"""Implements an inverse decay learning rate schedule with an optional exponential
warmup. When last_epoch=-1, sets initial lr as lr.
inv_gamma is the number of steps/epochs required for the learning rate to decay to
(1 / 2)**power of its original value.
Args:
optimizer (Optimizer): Wrapped optimizer.
inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
power (float): Exponential factor of learning rate decay. Default: 1.
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
Default: 0.
min_lr (float): The minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""
def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
last_epoch=-1, verbose=False):
self.inv_gamma = inv_gamma
self.power = power
if not 0. <= warmup < 1:
raise ValueError('Invalid value for warmup')
self.warmup = warmup
self.min_lr = min_lr
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.")
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
warmup = 1 - self.warmup ** (self.last_epoch + 1)
lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
return [warmup * max(self.min_lr, base_lr * lr_mult)
for base_lr in self.base_lrs]
class ExponentialLR(optim.lr_scheduler._LRScheduler):
"""Implements an exponential learning rate schedule with an optional exponential
warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
continuously by decay (default 0.5) every num_steps steps.
Args:
optimizer (Optimizer): Wrapped optimizer.
num_steps (float): The number of steps to decay the learning rate by decay in.
decay (float): The factor by which to decay the learning rate every num_steps
steps. Default: 0.5.
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
Default: 0.
min_lr (float): The minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""
def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
last_epoch=-1, verbose=False):
self.num_steps = num_steps
self.decay = decay
if not 0. <= warmup < 1:
raise ValueError('Invalid value for warmup')
self.warmup = warmup
self.min_lr = min_lr
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.")
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
warmup = 1 - self.warmup ** (self.last_epoch + 1)
lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
return [warmup * max(self.min_lr, base_lr * lr_mult)
for base_lr in self.base_lrs]
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
"""Draws samples from an lognormal distribution."""
return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
"""Draws samples from an optionally truncated log-logistic distribution."""
min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
return u.logit().mul(scale).add(loc).exp().to(dtype)
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
"""Draws samples from an log-uniform distribution."""
min_value = math.log(min_value)
max_value = math.log(max_value)
return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
"""Draws samples from a truncated v-diffusion training timestep distribution."""
min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
return torch.tan(u * math.pi / 2) * sigma_data
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
"""Draws samples from a split lognormal distribution."""
n = torch.randn(shape, device=device, dtype=dtype).abs()
u = torch.rand(shape, device=device, dtype=dtype)
n_left = n * -scale_1 + loc
n_right = n * scale_2 + loc
ratio = scale_1 / (scale_1 + scale_2)
return torch.where(u < ratio, n_left, n_right).exp()
class FolderOfImages(data.Dataset):
"""Recursively finds all images in a directory. It does not support
classes/targets."""
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
def __init__(self, root, transform=None):
super().__init__()
self.root = Path(root)
self.transform = nn.Identity() if transform is None else transform
self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
def __repr__(self):
return f'FolderOfImages(root="{self.root}", len: {len(self)})'
def __len__(self):
return len(self.paths)
def __getitem__(self, key):
path = self.paths[key]
with open(path, 'rb') as f:
image = Image.open(f).convert('RGB')
image = self.transform(image)
return image,
class CSVLogger:
def __init__(self, filename, columns):
self.filename = Path(filename)
self.columns = columns
if self.filename.exists():
self.file = open(self.filename, 'a')
else:
self.file = open(self.filename, 'w')
self.write(*self.columns)
def write(self, *args):
print(*args, sep=',', file=self.file, flush=True)
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
"""A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
cudnn_old = torch.backends.cudnn.allow_tf32
matmul_old = torch.backends.cuda.matmul.allow_tf32
try:
if cudnn is not None:
torch.backends.cudnn.allow_tf32 = cudnn
if matmul is not None:
torch.backends.cuda.matmul.allow_tf32 = matmul
yield
finally:
if cudnn is not None:
torch.backends.cudnn.allow_tf32 = cudnn_old
if matmul is not None:
torch.backends.cuda.matmul.allow_tf32 = matmul_old
import torch
class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_dimensions = 2
latent_rgb_factors = None
latent_rgb_factors_bias = None
taesd_decoder_name = None
def process_in(self, latent):
return latent * self.scale_factor
def process_out(self, latent):
return latent / self.scale_factor
class SD15(LatentFormat):
def __init__(self, scale_factor=0.18215):
self.scale_factor = scale_factor
self.latent_rgb_factors = [
# R G B
[ 0.3512, 0.2297, 0.3227],
[ 0.3250, 0.4974, 0.2350],
[-0.2829, 0.1762, 0.2721],
[-0.2120, -0.2616, -0.7177]
]
self.taesd_decoder_name = "taesd_decoder"
class SDXL(LatentFormat):
scale_factor = 0.13025
def __init__(self):
self.latent_rgb_factors = [
# R G B
[ 0.3651, 0.4232, 0.4341],
[-0.2533, -0.0042, 0.1068],
[ 0.1076, 0.1111, -0.0362],
[-0.3165, -0.2492, -0.2188]
]
self.latent_rgb_factors_bias = [ 0.1084, -0.0175, -0.0011]
self.taesd_decoder_name = "taesdxl_decoder"
class SDXL_Playground_2_5(LatentFormat):
def __init__(self):
self.scale_factor = 0.5
self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder"
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333
self.latent_rgb_factors = [
[-0.2340, -0.3863, -0.3257],
[ 0.0994, 0.0885, -0.0908],
[-0.2833, -0.2349, -0.3741],
[ 0.2523, -0.0055, -0.1651]
]
class SC_Prior(LatentFormat):
latent_channels = 16
def __init__(self):
self.scale_factor = 1.0
self.latent_rgb_factors = [
[-0.0326, -0.0204, -0.0127],
[-0.1592, -0.0427, 0.0216],
[ 0.0873, 0.0638, -0.0020],
[-0.0602, 0.0442, 0.1304],
[ 0.0800, -0.0313, -0.1796],
[-0.0810, -0.0638, -0.1581],
[ 0.1791, 0.1180, 0.0967],
[ 0.0740, 0.1416, 0.0432],
[-0.1745, -0.1888, -0.1373],
[ 0.2412, 0.1577, 0.0928],
[ 0.1908, 0.0998, 0.0682],
[ 0.0209, 0.0365, -0.0092],
[ 0.0448, -0.0650, -0.1728],
[-0.1658, -0.1045, -0.1308],
[ 0.0542, 0.1545, 0.1325],
[-0.0352, -0.1672, -0.2541]
]
class SC_B(LatentFormat):
def __init__(self):
self.scale_factor = 1.0 / 0.43
self.latent_rgb_factors = [
[ 0.1121, 0.2006, 0.1023],
[-0.2093, -0.0222, -0.0195],
[-0.3087, -0.1535, 0.0366],
[ 0.0290, -0.1574, -0.4078]
]
class SD3(LatentFormat):
latent_channels = 16
def __init__(self):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
self.latent_rgb_factors = [
[-0.0922, -0.0175, 0.0749],
[ 0.0311, 0.0633, 0.0954],
[ 0.1994, 0.0927, 0.0458],
[ 0.0856, 0.0339, 0.0902],
[ 0.0587, 0.0272, -0.0496],
[-0.0006, 0.1104, 0.0309],
[ 0.0978, 0.0306, 0.0427],
[-0.0042, 0.1038, 0.1358],
[-0.0194, 0.0020, 0.0669],
[-0.0488, 0.0130, -0.0268],
[ 0.0922, 0.0988, 0.0951],
[-0.0278, 0.0524, -0.0542],
[ 0.0332, 0.0456, 0.0895],
[-0.0069, -0.0030, -0.0810],
[-0.0596, -0.0465, -0.0293],
[-0.1448, -0.1463, -0.1189]
]
self.latent_rgb_factors_bias = [0.2394, 0.2135, 0.1925]
self.taesd_decoder_name = "taesd3_decoder"
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
class StableAudio1(LatentFormat):
latent_channels = 64
latent_dimensions = 1
class Flux(SD3):
latent_channels = 16
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.latent_rgb_factors =[
[-0.0346, 0.0244, 0.0681],
[ 0.0034, 0.0210, 0.0687],
[ 0.0275, -0.0668, -0.0433],
[-0.0174, 0.0160, 0.0617],
[ 0.0859, 0.0721, 0.0329],
[ 0.0004, 0.0383, 0.0115],
[ 0.0405, 0.0861, 0.0915],
[-0.0236, -0.0185, -0.0259],
[-0.0245, 0.0250, 0.1180],
[ 0.1008, 0.0755, -0.0421],
[-0.0515, 0.0201, 0.0011],
[ 0.0428, -0.0012, -0.0036],
[ 0.0817, 0.0765, 0.0749],
[-0.1264, -0.0522, -0.1103],
[-0.0280, -0.0881, -0.0499],
[-0.1262, -0.0982, -0.0778]
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.taesd_decoder_name = "taef1_decoder"
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([-0.06730895953510081, -0.038011381506090416, -0.07477820912866141,
-0.05565264470995561, 0.012767231469026969, -0.04703542746246419,
0.043896967884726704, -0.09346305707025976, -0.09918314763016893,
-0.008729793427399178, -0.011931556316503654, -0.0321993391887285]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([0.9263795028493863, 0.9248894543193766, 0.9393059390890617,
0.959253732819592, 0.8244560132752793, 0.917259975397747,
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
self.latent_rgb_factors =[
[-0.0069, -0.0045, 0.0018],
[ 0.0154, -0.0692, -0.0274],
[ 0.0333, 0.0019, 0.0206],
[-0.1390, 0.0628, 0.1678],
[-0.0725, 0.0134, -0.1898],
[ 0.0074, -0.0270, -0.0209],
[-0.0176, -0.0277, -0.0221],
[ 0.5294, 0.5204, 0.3852],
[-0.0326, -0.0446, -0.0143],
[-0.0659, 0.0153, -0.0153],
[ 0.0185, -0.0217, 0.0014],
[-0.0396, -0.0495, -0.0281]
]
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
[ 0.0696, 0.0795, 0.0518],
[ 0.0135, -0.0945, -0.0282],
[ 0.0108, -0.0250, -0.0765],
[-0.0209, 0.0032, 0.0224],
[-0.0804, -0.0254, -0.0639],
[-0.0991, 0.0271, -0.0669],
[-0.0646, -0.0422, -0.0400],
[-0.0696, -0.0595, -0.0894],
[-0.0799, -0.0208, -0.0375],
[ 0.1166, 0.1627, 0.0962],
[ 0.1165, 0.0432, 0.0407],
[-0.2315, -0.1920, -0.1355],
[-0.0270, 0.0401, -0.0821],
[-0.0616, -0.0997, -0.0727],
[ 0.0249, -0.0469, -0.1703]
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[ 0.1817, 0.2284, 0.2423],
[-0.0586, -0.0862, -0.3108],
[-0.4703, -0.4255, -0.3995],
[ 0.0803, 0.1963, 0.1001],
[-0.0820, -0.1050, 0.0400],
[ 0.2511, 0.3098, 0.2787],
[-0.1830, -0.2117, -0.0040],
[-0.0621, -0.2187, -0.0939],
[ 0.3619, 0.1082, 0.1455],
[ 0.3164, 0.3922, 0.2575],
[ 0.1152, 0.0231, -0.0462],
[-0.1434, -0.3609, -0.3665],
[ 0.0635, 0.1471, 0.1680],
[-0.3635, -0.1963, -0.3248],
[-0.1865, 0.0365, 0.2346],
[ 0.0447, 0.0994, 0.0881]
]
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
class Wan21(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932],
[ 0.0671, 0.0406, 0.0442],
[ 0.3568, 0.2548, 0.1747],
[ 0.0372, 0.2344, 0.1420],
[ 0.0313, 0.0189, -0.0328],
[ 0.0296, -0.0956, -0.0665],
[-0.3477, -0.4059, -0.2925],
[ 0.0166, 0.1902, 0.1975],
[-0.0412, 0.0267, -0.1364],
[-0.1293, 0.0740, 0.1636],
[ 0.0680, 0.3019, 0.1128],
[ 0.0032, 0.0581, 0.0639],
[-0.1251, 0.0927, 0.1699],
[ 0.0060, -0.0633, 0.0005],
[ 0.3477, 0.2275, 0.2950],
[ 0.1984, 0.0913, 0.1861]
]
latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]).view(1, self.latent_channels, 1, 1, 1)
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],
[-0.1062, -0.0504, 0.0165],
[ 0.0140, 0.0409, 0.0491],
[-0.0813, -0.0677, 0.0607],
[ 0.0656, 0.0851, 0.0808],
[ 0.0264, 0.0463, 0.0912],
[ 0.0295, 0.0326, 0.0590],
[-0.0244, -0.0270, 0.0025],
[ 0.0443, -0.0102, 0.0288],
[-0.0465, -0.0090, -0.0205],
[ 0.0359, 0.0236, 0.0082],
[-0.0776, 0.0854, 0.1048],
[ 0.0564, 0.0264, 0.0561],
[ 0.0006, 0.0594, 0.0418],
[-0.0319, -0.0542, -0.0637],
[-0.0268, 0.0024, 0.0260],
[ 0.0539, 0.0265, 0.0358],
[-0.0359, -0.0312, -0.0287],
[-0.0285, -0.1032, -0.1237],
[ 0.1041, 0.0537, 0.0622],
[-0.0086, -0.0374, -0.0051],
[ 0.0390, 0.0670, 0.2863],
[ 0.0069, 0.0144, 0.0082],
[ 0.0006, -0.0167, 0.0079],
[ 0.0313, -0.0574, -0.0232],
[-0.1454, -0.0902, -0.0481],
[ 0.0714, 0.0827, 0.0447],
[-0.0304, -0.0574, -0.0196],
[ 0.0401, 0.0384, 0.0204],
[-0.0758, -0.0297, -0.0014],
[ 0.0568, 0.1307, 0.1372],
[-0.0055, -0.0310, -0.0380],
[ 0.0239, -0.0305, 0.0325],
[-0.0663, -0.0673, -0.0140],
[-0.0416, -0.0047, -0.0023],
[ 0.0166, 0.0112, -0.0093],
[-0.0211, 0.0011, 0.0331],
[ 0.1833, 0.1466, 0.2250],
[-0.0368, 0.0370, 0.0295],
[-0.3441, -0.3543, -0.2008],
[-0.0479, -0.0489, -0.0420],
[-0.0660, -0.0153, 0.0800],
[-0.0101, 0.0068, 0.0156],
[-0.0690, -0.0452, -0.0927],
[-0.0145, 0.0041, 0.0015],
[ 0.0421, 0.0451, 0.0373],
[ 0.0504, -0.0483, -0.0356],
[-0.0837, 0.0168, 0.0055]
]
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 1.0188137142395404
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union, Optional
import torch
import torch.nn.functional as F
from torch import nn
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
kv_heads: Optional[int] = None,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
processor=None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
is_causal: bool = False,
dtype=None, device=None, operations=None
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.is_causal = is_causal
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
self.group_norm = None
self.spatial_norm = None
self.norm_q = None
self.norm_k = None
self.norm_cross = None
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
else:
self.to_k = None
self.to_v = None
self.added_proj_bias = added_proj_bias
if self.added_kv_proj_dim is not None:
self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
if self.context_pre_only is not None:
self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
else:
self.add_q_proj = None
self.add_k_proj = None
self.add_v_proj = None
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
self.to_out.append(nn.Dropout(dropout))
else:
self.to_out = None
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
else:
self.to_add_out = None
self.norm_added_q = None
self.norm_added_k = None
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
class CustomLiteLAProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
def __init__(self):
self.kernel_func = nn.ReLU(inplace=False)
self.eps = 1e-15
self.pad_val = 1.0
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
hidden_states_len = hidden_states.shape[1]
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
if encoder_hidden_states is not None:
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = hidden_states.shape[0]
# `sample` projections.
dtype = hidden_states.dtype
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
if not attn.is_cross_attention:
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
else:
query = hidden_states
key = encoder_hidden_states
value = encoder_hidden_states
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
# RoPE需要 [B, H, S, D] 输入
# 此时 query是 [B, H, D, S], 需要转成 [B, H, S, D] 才能应用RoPE
query = query.permute(0, 1, 3, 2) # [B, H, S, D] (从 [B, H, D, S])
# Apply query and key normalization if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
# 此时 query是 [B, H, S, D],需要还原成 [B, H, D, S]
query = query.permute(0, 1, 3, 2) # [B, H, D, S]
if attention_mask is not None:
# attention_mask: [B, S] -> [B, 1, S, 1]
attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, S, D] * [B, 1, S, 1]
if not attn.is_cross_attention:
key = key * attention_mask # key: [B, h, S, D] 与 mask [B, 1, S, 1] 相乘
value = value * attention_mask.permute(0, 1, 3, 2) # 如果 value 是 [B, h, D, S],那么需调整mask以匹配S维度
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
# 此时 key: [B, h, S_enc, D], value: [B, h, D, S_enc]
key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
query = self.kernel_func(query)
key = self.kernel_func(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
vk = torch.matmul(value, key)
hidden_states = torch.matmul(vk, query)
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.float()
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
hidden_states = hidden_states.to(dtype)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.to(dtype)
# Split the attention outputs.
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
hidden_states, encoder_hidden_states = (
hidden_states[:, : hidden_states_len],
hidden_states[:, hidden_states_len:],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if encoder_hidden_states is not None and context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if torch.get_autocast_gpu_dtype() == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states
class CustomerAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
# attention_mask: N x S1
# encoder_attention_mask: N x S2
# cross attention 整合attention_mask和encoder_attention_mask
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
elif not attn.is_cross_attention and attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
).to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
"""Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
if isinstance(x, (list, tuple)):
return list(x)
return [x for _ in range(repeat_time)]
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
"""Return tuple with min_len by repeating element at idx_repeat."""
# convert to list first
x = val2list(x)
# repeat elements if necessary
if len(x) > 0:
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
return tuple(x)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
if isinstance(kernel_size, tuple):
return tuple([get_same_padding(ks) for ks in kernel_size])
else:
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
return kernel_size // 2
class ConvLayer(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
padding: Union[int, None] = None,
use_bias=False,
norm=None,
act=None,
dtype=None, device=None, operations=None
):
super().__init__()
if padding is None:
padding = get_same_padding(kernel_size)
padding *= dilation
self.in_dim = in_dim
self.out_dim = out_dim
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.groups = groups
self.padding = padding
self.use_bias = use_bias
self.conv = operations.Conv1d(
in_dim,
out_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=use_bias,
device=device,
dtype=dtype
)
if norm is not None:
self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
else:
self.norm = None
if act is not None:
self.act = nn.SiLU(inplace=True)
else:
self.act = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
if self.norm:
x = self.norm(x)
if self.act:
x = self.act(x)
return x
class GLUMBConv(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int,
out_feature=None,
kernel_size=3,
stride=1,
padding: Union[int, None] = None,
use_bias=False,
norm=(None, None, None),
act=("silu", "silu", None),
dilation=1,
dtype=None, device=None, operations=None
):
out_feature = out_feature or in_features
super().__init__()
use_bias = val2tuple(use_bias, 3)
norm = val2tuple(norm, 3)
act = val2tuple(act, 3)
self.glu_act = nn.SiLU(inplace=False)
self.inverted_conv = ConvLayer(
in_features,
hidden_features * 2,
1,
use_bias=use_bias[0],
norm=norm[0],
act=act[0],
dtype=dtype,
device=device,
operations=operations,
)
self.depth_conv = ConvLayer(
hidden_features * 2,
hidden_features * 2,
kernel_size,
stride=stride,
groups=hidden_features * 2,
padding=padding,
use_bias=use_bias[1],
norm=norm[1],
act=None,
dilation=dilation,
dtype=dtype,
device=device,
operations=operations,
)
self.point_conv = ConvLayer(
hidden_features,
out_feature,
1,
use_bias=use_bias[2],
norm=norm[2],
act=act[2],
dtype=dtype,
device=device,
operations=operations,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 2)
x = self.inverted_conv(x)
x = self.depth_conv(x)
x, gate = torch.chunk(x, 2, dim=1)
gate = self.glu_act(gate)
x = x * gate
x = self.point_conv(x)
x = x.transpose(1, 2)
return x
class LinearTransformerBlock(nn.Module):
"""
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
use_adaln_single=True,
cross_attention_dim=None,
added_kv_proj_dim=None,
context_pre_only=False,
mlp_ratio=4.0,
add_cross_attention=False,
add_cross_attention_dim=None,
qk_norm=None,
dtype=None, device=None, operations=None
):
super().__init__()
self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
added_kv_proj_dim=added_kv_proj_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm=qk_norm,
processor=CustomLiteLAProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.add_cross_attention = add_cross_attention
self.context_pre_only = context_pre_only
if add_cross_attention and add_cross_attention_dim is not None:
self.cross_attn = Attention(
query_dim=dim,
cross_attention_dim=add_cross_attention_dim,
added_kv_proj_dim=add_cross_attention_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
qk_norm=qk_norm,
processor=CustomerAttnProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
self.ff = GLUMBConv(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
use_bias=(True, True, False),
norm=(None, None, None),
act=("silu", "silu", None),
dtype=dtype,
device=device,
operations=operations,
)
self.use_adaln_single = use_adaln_single
if use_adaln_single:
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: torch.FloatTensor = None,
encoder_attention_mask: torch.FloatTensor = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
):
N = hidden_states.shape[0]
# step 1: AdaLN single
if self.use_adaln_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
# step 2: attention
if not self.add_cross_attention:
attn_output, encoder_hidden_states = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
)
else:
attn_output, _ = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
)
if self.use_adaln_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if self.add_cross_attention:
attn_output = self.cross_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
)
hidden_states = attn_output + hidden_states
# step 3: add norm
norm_hidden_states = self.norm2(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
# step 4: feed forward
ff_output = self.ff(norm_hidden_states)
if self.use_adaln_single:
ff_output = gate_mlp * ff_output
hidden_states = hidden_states + ff_output
return hidden_states
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/lyrics_utils/lyric_encoder.py
from typing import Optional, Tuple, Union
import math
import torch
from torch import nn
import comfy.model_management
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int = 15,
activation: nn.Module = nn.ReLU(),
norm: str = "batch_norm",
causal: bool = False,
bias: bool = True,
dtype=None, device=None, operations=None):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
causal (int): Whether use causal convolution or not
"""
super().__init__()
self.pointwise_conv1 = operations.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
dtype=dtype, device=device
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0: it's a causal convolution, the input will be
# padded with self.lorder frames on the left in forward.
# else: it's a symmetrical convolution
if causal:
padding = 0
self.lorder = kernel_size - 1
else:
# kernel_size should be an odd number for none causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = operations.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
dtype=dtype, device=device
)
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1d(channels)
else:
self.use_layer_norm = True
self.norm = operations.LayerNorm(channels, dtype=dtype, device=device)
self.pointwise_conv2 = operations.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
dtype=dtype, device=device
)
self.activation = activation
def forward(
self,
x: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
cache: torch.Tensor = torch.zeros((0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
(0, 0, 0) means fake mask.
cache (torch.Tensor): left context cache, it is only
used in causal convolution (#batch, channels, cache_t),
(0, 0, 0) meas fake cache.
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2) # (#batch, channels, time)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
if self.lorder > 0:
if cache.size(2) == 0: # cache_t == 0
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
else:
assert cache.size(0) == x.size(0) # equal batch
assert cache.size(1) == x.size(1) # equal channel
x = torch.cat((cache, x), dim=2)
assert (x.size(2) > self.lorder)
new_cache = x[:, :, -self.lorder:]
else:
# It's better we just return None if no cache is required,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.pointwise_conv2(x)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
return x.transpose(1, 2), new_cache
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
FeedForward are appied on each position of the sequence.
The output dim is same with the input dim.
Args:
idim (int): Input dimenstion.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function
"""
def __init__(
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
dtype=None, device=None, operations=None
):
"""Construct a PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
self.w_1 = operations.Linear(idim, hidden_units, dtype=dtype, device=device)
self.activation = activation
self.dropout = torch.nn.Dropout(dropout_rate)
self.w_2 = operations.Linear(hidden_units, idim, dtype=dtype, device=device)
def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
xs: input tensor (B, L, D)
Returns:
output tensor, (B, L, D)
"""
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Return Swish activation function."""
return x * torch.sigmoid(x)
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
dtype=None, device=None, operations=None):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
self.linear_k = operations.Linear(n_feat, n_feat, bias=key_bias, dtype=dtype, device=device)
self.linear_v = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
self.linear_out = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
self.dropout = nn.Dropout(p=dropout_rate)
def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor, size
(#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor, size
(#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(
self,
value: torch.Tensor,
scores: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value, size
(#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score, size
(#batch, n_head, time1, time2).
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
if mask is not None and mask.size(2) > 0: # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
1.When applying cross attention between decoder and encoder,
the batch padding mask for input is in (#batch, 1, T) shape.
2.When applying self attention of encoder,
the mask is in (#batch, T, T) shape.
3.When applying self attention of decoder,
the mask is in (#batch, L, L) shape.
4.If the different position in decoder see different block
of the encoder, such as Mocha, the passed in mask could be
in (#batch, L, T) shape. But there is no such case in current
CosyVoice.
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
cache.size(-1) // 2,
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
dtype=None, device=None, operations=None):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, key_bias, dtype=dtype, device=device, operations=operations)
# linear transformation for positional encoding
self.linear_pos = operations.Linear(n_feat, n_feat, bias=False, dtype=dtype, device=device)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.empty(self.h, self.d_k, dtype=dtype, device=device))
self.pos_bias_v = nn.Parameter(torch.empty(self.h, self.d_k, dtype=dtype, device=device))
# torch.nn.init.xavier_uniform_(self.pos_bias_u)
# torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
"""Compute relative positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.size()[0],
x.size()[1],
x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)[
:, :, :, : x.size(-1) // 2 + 1
] # only keep the positions from 0 to time2
return x
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
cache.size(-1) // 2,
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + comfy.model_management.cast_to(self.pos_bias_u, dtype=q.dtype, device=q.device)).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + comfy.model_management.cast_to(self.pos_bias_v, dtype=q.dtype, device=q.device)).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
if matrix_ac.shape != matrix_bd.shape:
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask), new_cache
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode.
This means the current step could only do attention with its left steps.
In encoder, fully attention is used when streaming is not necessary and
the sequence is not long. In this case, no attention mask is needed.
When streaming is need, chunk-based attention is used in encoder. See
subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
str device (str): "cpu" or "cuda" or torch.Tensor.device
dtype (torch.device): result dtype
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
arange = torch.arange(size, device=device)
mask = arange.expand(size, size)
arange = arange.unsqueeze(-1)
mask = mask <= arange
return mask
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
return ret
def add_optional_chunk_mask(xs: torch.Tensor,
masks: torch.Tensor,
use_dynamic_chunk: bool,
use_dynamic_left_chunk: bool,
decoding_chunk_size: int,
static_chunk_size: int,
num_decoding_left_chunks: int,
enable_full_context: bool = True):
""" Apply optional mask for encoder.
Args:
xs (torch.Tensor): padded input, (B, L, D), L for max length
mask (torch.Tensor): mask for xs, (B, 1, L)
use_dynamic_chunk (bool): whether to use dynamic chunk or not
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
training.
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
static_chunk_size (int): chunk size for static chunk training/decoding
if it's greater than 0, if use_dynamic_chunk is true,
this parameter will be ignored
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
enable_full_context (bool):
True: chunk size is either [1, 25] or full context(max_len)
False: chunk size ~ U[1, 25]
Returns:
torch.Tensor: chunk mask of the input xs.
"""
# Whether to use chunk mask or not
if use_dynamic_chunk:
max_len = xs.size(1)
if decoding_chunk_size < 0:
chunk_size = max_len
num_left_chunks = -1
elif decoding_chunk_size > 0:
chunk_size = decoding_chunk_size
num_left_chunks = num_decoding_left_chunks
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2 and enable_full_context:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
class ConformerEncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
`ConvlutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
"""
def __init__(
self,
size: int,
self_attn: torch.nn.Module,
feed_forward: Optional[nn.Module] = None,
feed_forward_macaron: Optional[nn.Module] = None,
conv_module: Optional[nn.Module] = None,
dropout_rate: float = 0.1,
normalize_before: bool = True,
dtype=None, device=None, operations=None
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device) # for the FNN module
self.norm_mha = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device) # for the CNN module
self.norm_final = operations.LayerNorm(
size, eps=1e-5, dtype=dtype, device=device) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, time, size)
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
(0, 0, 0) means fake mask.
pos_emb (torch.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
mask_pad (torch.Tensor): batch padding mask used for conv module.
(#batch, 1,time), (0, 0, 0) means fake mask.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time).
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
"""
# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
att_cache)
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
return x, mask, new_att_cache, new_cnn_cache
class EspnetRelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
"""Construct an PositionalEncoding object."""
super(EspnetRelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: torch.Tensor):
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
return self.dropout(x), self.dropout(pos_emb)
def position_encoding(self,
offset: Union[int, torch.Tensor],
size: int) -> torch.Tensor:
""" For getting encoding in a streaming fashion
Attention!!!!!
we apply dropout only once at the whole utterance level in a none
streaming way, but will call this function several times with
increasing input size in a streaming scenario, so the dropout will
be applied several times.
Args:
offset (int or torch.tensor): start offset
size (int): required size of position encoding
Returns:
torch.Tensor: Corresponding encoding
"""
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
]
return pos_emb
class LinearEmbed(torch.nn.Module):
"""Linear transform the input without subsampling
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module, dtype=None, device=None, operations=None):
"""Construct an linear object."""
super().__init__()
self.out = torch.nn.Sequential(
operations.Linear(idim, odim, dtype=dtype, device=device),
operations.LayerNorm(odim, eps=1e-5, dtype=dtype, device=device),
torch.nn.Dropout(dropout_rate),
)
self.pos_enc = pos_enc_class #rel_pos_espnet
def position_encoding(self, offset: Union[int, torch.Tensor],
size: int) -> torch.Tensor:
return self.pos_enc.position_encoding(offset, size)
def forward(
self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Input x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
torch.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x = self.out(x)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb
ATTENTION_CLASSES = {
"selfattn": MultiHeadedAttention,
"rel_selfattn": RelPositionMultiHeadedAttention,
}
ACTIVATION_CLASSES = {
"hardtanh": torch.nn.Hardtanh,
"tanh": torch.nn.Tanh,
"relu": torch.nn.ReLU,
"selu": torch.nn.SELU,
"swish": getattr(torch.nn, "SiLU", Swish),
"gelu": torch.nn.GELU,
}
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
Returns:
torch.Tensor: Mask tensor containing indices of padded part.
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0,
max_len,
dtype=torch.int64,
device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
#https://github.com/FunAudioLLM/CosyVoice/blob/main/examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml
class ConformerEncoder(torch.nn.Module):
"""Conformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int = 1024,
attention_heads: int = 16,
linear_units: int = 4096,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = 'linear',
pos_enc_layer_type: str = 'rel_pos_espnet',
normalize_before: bool = True,
static_chunk_size: int = 1, # 1: causal_mask; 0: full_mask
use_dynamic_chunk: bool = False,
use_dynamic_left_chunk: bool = False,
positionwise_conv_kernel_size: int = 1,
macaron_style: bool =False,
selfattention_layer_type: str = "rel_selfattn",
activation_type: str = "swish",
use_cnn_module: bool = False,
cnn_module_kernel: int = 15,
causal: bool = False,
cnn_module_norm: str = "batch_norm",
key_bias: bool = True,
dtype=None, device=None, operations=None
):
"""Construct ConformerEncoder
Args:
input_size to use_dynamic_chunk, see in BaseEncoder
positionwise_conv_kernel_size (int): Kernel size of positionwise
conv1d layer.
macaron_style (bool): Whether to use macaron style for
positionwise layer.
selfattention_layer_type (str): Encoder attention layer type,
the parameter has no effect now, it's just for configure
compatibility. #'rel_selfattn'
activation_type (str): Encoder activation function type.
use_cnn_module (bool): Whether to use convolution module.
cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): whether to use causal convolution or not.
key_bias: whether use bias in attention.linear_k, False for whisper models.
"""
super().__init__()
self.output_size = output_size
self.embed = LinearEmbed(input_size, output_size, dropout_rate,
EspnetRelPositionalEncoding(output_size, positional_dropout_rate), dtype=dtype, device=device, operations=operations)
self.normalize_before = normalize_before
self.after_norm = operations.LayerNorm(output_size, eps=1e-5, dtype=dtype, device=device)
self.use_dynamic_chunk = use_dynamic_chunk
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
activation = ACTIVATION_CLASSES[activation_type]()
# self-attention module definition
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
key_bias,
)
# feed-forward module definition
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal)
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,
RelPositionMultiHeadedAttention(
*encoder_selfattn_layer_args, dtype=dtype, device=device, operations=operations),
PositionwiseFeedForward(*positionwise_layer_args, dtype=dtype, device=device, operations=operations),
PositionwiseFeedForward(
*positionwise_layer_args, dtype=dtype, device=device, operations=operations) if macaron_style else None,
ConvolutionModule(
*convolution_layer_args, dtype=dtype, device=device, operations=operations) if use_cnn_module else None,
dropout_rate,
normalize_before, dtype=dtype, device=device, operations=operations
) for _ in range(num_blocks)
])
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor) -> torch.Tensor:
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
return xs
def forward(
self,
xs: torch.Tensor,
pad_mask: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
NOTE(xcsong):
We pass the `__call__` method of the modules instead of `forward` to the
checkpointing API because `__call__` attaches all the hooks of the module.
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
"""
masks = None
if pad_mask is not None:
masks = pad_mask.to(torch.bool).unsqueeze(1) # (B, 1, T)
xs, pos_emb = self.embed(xs)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, List, Union
import torch
from torch import nn
import comfy.model_management
import comfy.patcher_extension
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
from .lyric_encoder import ConformerEncoder as LyricEncoder
def cross_norm(hidden_states, controlnet_input):
# input N x T x c
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
return controlnet_input
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
class T2IFinalLayer(nn.Module):
"""
The final layer of Sana.
"""
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
self.out_channels = out_channels
self.patch_size = patch_size
def unpatchfy(
self,
hidden_states: torch.Tensor,
width: int,
):
# 4 unpatchify
new_height, new_width = 1, hidden_states.size(1)
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
).contiguous()
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
).contiguous()
if width > new_width:
output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
elif width < new_width:
output = output[:, :, :, :width]
return output
def forward(self, x, t, output_length):
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
# unpatchify
output = self.unpatchfy(x, output_length)
return output
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
height=16,
width=4096,
patch_size=(16, 1),
in_channels=8,
embed_dim=1152,
bias=True,
dtype=None, device=None, operations=None
):
super().__init__()
patch_size_h, patch_size_w = patch_size
self.early_conv_layers = nn.Sequential(
operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
)
self.patch_size = patch_size
self.height, self.width = height // patch_size_h, width // patch_size_w
self.base_size = self.width
def forward(self, latent):
# early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
latent = self.early_conv_layers(latent)
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
return latent
class ACEStepTransformer2DModel(nn.Module):
# _supports_gradient_checkpointing = True
def __init__(
self,
in_channels: Optional[int] = 8,
num_layers: int = 28,
inner_dim: int = 1536,
attention_head_dim: int = 64,
num_attention_heads: int = 24,
mlp_ratio: float = 4.0,
out_channels: int = 8,
max_position: int = 32768,
rope_theta: float = 1000000.0,
speaker_embedding_dim: int = 512,
text_embedding_dim: int = 768,
ssl_encoder_depths: List[int] = [9, 9],
ssl_names: List[str] = ["mert", "m-hubert"],
ssl_latent_dims: List[int] = [1024, 768],
lyric_encoder_vocab_size: int = 6681,
lyric_hidden_size: int = 1024,
patch_size: List[int] = [16, 1],
max_height: int = 16,
max_width: int = 4096,
audio_model=None,
dtype=None, device=None, operations=None
):
super().__init__()
self.dtype = dtype
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
self.out_channels = out_channels
self.max_position = max_position
self.patch_size = patch_size
self.rope_theta = rope_theta
self.rotary_emb = Qwen2RotaryEmbedding(
dim=self.attention_head_dim,
max_position_embeddings=self.max_position,
base=self.rope_theta,
dtype=dtype,
device=device,
)
# 2. Define input layers
self.in_channels = in_channels
self.num_layers = num_layers
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
LinearTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_ratio=mlp_ratio,
add_cross_attention=True,
add_cross_attention_dim=self.inner_dim,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(self.num_layers)
]
)
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
# speaker
self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
# genre
self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
# lyric
self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
projector_dim = 2 * self.inner_dim
self.projectors = nn.ModuleList([
nn.Sequential(
operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
) for ssl_dim in ssl_latent_dims
])
self.proj_in = PatchEmbed(
height=max_height,
width=max_width,
patch_size=patch_size,
embed_dim=self.inner_dim,
bias=True,
dtype=dtype,
device=device,
operations=operations,
)
self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
def forward_lyric_encoder(
self,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
out_dtype=None,
):
# N x T x D
lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
return prompt_prenet_out
def encode(
self,
encoder_text_hidden_states: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
lyrics_strength=1.0,
):
bs = encoder_text_hidden_states.shape[0]
device = encoder_text_hidden_states.device
# speaker embedding
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
# genre embedding
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
# lyric
encoder_lyric_hidden_states = self.forward_lyric_encoder(
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
out_dtype=encoder_text_hidden_states.dtype,
)
encoder_lyric_hidden_states *= lyrics_strength
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
encoder_hidden_mask = None
if text_attention_mask is not None:
speaker_mask = torch.ones(bs, 1, device=device)
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
return encoder_hidden_states, encoder_hidden_mask
def decode(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_mask: torch.Tensor,
timestep: Optional[torch.Tensor],
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
hidden_states = self.proj_in(hidden_states)
# controlnet logic
if block_controlnet_hidden_states is not None:
control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
hidden_states = hidden_states + control_condi * controlnet_scale
# inner_hidden_states = []
rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_hidden_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output
def forward(self,
x,
timestep,
attention_mask=None,
context: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
lyrics_strength=1.0,
**kwargs
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
controlnet_scale, lyrics_strength, **kwargs)
def _forward(
self,
x,
timestep,
attention_mask=None,
context: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
lyrics_strength=1.0,
**kwargs
):
hidden_states = x
encoder_text_hidden_states = context
encoder_hidden_states, encoder_hidden_mask = self.encode(
encoder_text_hidden_states=encoder_text_hidden_states,
text_attention_mask=text_attention_mask,
speaker_embeds=speaker_embeds,
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
lyrics_strength=lyrics_strength,
)
output_length = hidden_states.shape[-1]
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
timestep=timestep,
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
)
return output
# Rewritten from diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Union
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
class RMSNorm(ops.RMSNorm):
def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
if elementwise_affine:
self.bias = nn.Parameter(torch.empty(dim)) if bias else None
def forward(self, x):
x = super().forward(x)
if self.elementwise_affine:
if self.bias is not None:
x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
return x
def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
if norm_type == "batch_norm":
return nn.BatchNorm2d(num_features)
elif norm_type == "group_norm":
return ops.GroupNorm(num_groups, num_features)
elif norm_type == "layer_norm":
return ops.LayerNorm(num_features)
elif norm_type == "rms_norm":
return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
else:
raise ValueError(f"Unknown normalization type: {norm_type}")
def get_activation(activation_type):
if activation_type == "relu":
return nn.ReLU()
elif activation_type == "relu6":
return nn.ReLU6()
elif activation_type == "silu":
return nn.SiLU()
elif activation_type == "leaky_relu":
return nn.LeakyReLU(0.2)
else:
raise ValueError(f"Unknown activation type: {activation_type}")
class ResBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
norm_type: str = "batch_norm",
act_fn: str = "relu6",
) -> None:
super().__init__()
self.norm_type = norm_type
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
self.norm = get_normalization(norm_type, out_channels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = self.norm(hidden_states)
return hidden_states + residual
class SanaMultiscaleAttentionProjection(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
kernel_size: int,
) -> None:
super().__init__()
channels = 3 * in_channels
self.proj_in = ops.Conv2d(
channels,
channels,
kernel_size,
padding=kernel_size // 2,
groups=channels,
bias=False,
)
self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states
class SanaMultiscaleLinearAttention(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_attention_heads: int = None,
attention_head_dim: int = 8,
mult: float = 1.0,
norm_type: str = "batch_norm",
kernel_sizes: tuple = (5,),
eps: float = 1e-15,
residual_connection: bool = False,
):
super().__init__()
self.eps = eps
self.attention_head_dim = attention_head_dim
self.norm_type = norm_type
self.residual_connection = residual_connection
num_attention_heads = (
int(in_channels // attention_head_dim * mult)
if num_attention_heads is None
else num_attention_heads
)
inner_dim = num_attention_heads * attention_head_dim
self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
self.to_qkv_multiscale = nn.ModuleList()
for kernel_size in kernel_sizes:
self.to_qkv_multiscale.append(
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
)
self.nonlinearity = nn.ReLU()
self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
self.norm_out = get_normalization(norm_type, out_channels)
def apply_linear_attention(self, query, key, value):
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
scores = torch.matmul(value, key.transpose(-1, -2))
hidden_states = torch.matmul(scores, query)
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
return hidden_states
def apply_quadratic_attention(self, query, key, value):
scores = torch.matmul(key.transpose(-1, -2), query)
scores = scores.to(dtype=torch.float32)
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
hidden_states = torch.matmul(value, scores.to(value.dtype))
return hidden_states
def forward(self, hidden_states):
height, width = hidden_states.shape[-2:]
if height * width > self.attention_head_dim:
use_linear_attention = True
else:
use_linear_attention = False
residual = hidden_states
batch_size, _, height, width = list(hidden_states.size())
original_dtype = hidden_states.dtype
hidden_states = hidden_states.movedim(1, -1)
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
hidden_states = torch.cat([query, key, value], dim=3)
hidden_states = hidden_states.movedim(-1, 1)
multi_scale_qkv = [hidden_states]
for block in self.to_qkv_multiscale:
multi_scale_qkv.append(block(hidden_states))
hidden_states = torch.cat(multi_scale_qkv, dim=1)
if use_linear_attention:
# for linear attention upcast hidden_states to float32
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
query, key, value = hidden_states.chunk(3, dim=2)
query = self.nonlinearity(query)
key = self.nonlinearity(key)
if use_linear_attention:
hidden_states = self.apply_linear_attention(query, key, value)
hidden_states = hidden_states.to(dtype=original_dtype)
else:
hidden_states = self.apply_quadratic_attention(query, key, value)
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.norm_type == "rms_norm":
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = self.norm_out(hidden_states)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
class EfficientViTBlock(nn.Module):
def __init__(
self,
in_channels: int,
mult: float = 1.0,
attention_head_dim: int = 32,
qkv_multiscales: tuple = (5,),
norm_type: str = "batch_norm",
) -> None:
super().__init__()
self.attn = SanaMultiscaleLinearAttention(
in_channels=in_channels,
out_channels=in_channels,
mult=mult,
attention_head_dim=attention_head_dim,
norm_type=norm_type,
kernel_sizes=qkv_multiscales,
residual_connection=True,
)
self.conv_out = GLUMBConv(
in_channels=in_channels,
out_channels=in_channels,
norm_type="rms_norm",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.attn(x)
x = self.conv_out(x)
return x
class GLUMBConv(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: float = 4,
norm_type: str = None,
residual_connection: bool = True,
) -> None:
super().__init__()
hidden_channels = int(expand_ratio * in_channels)
self.norm_type = norm_type
self.residual_connection = residual_connection
self.nonlinearity = nn.SiLU()
self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
self.norm = None
if norm_type == "rms_norm":
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.residual_connection:
residual = hidden_states
hidden_states = self.conv_inverted(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_depth(hidden_states)
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
hidden_states = hidden_states * self.nonlinearity(gate)
hidden_states = self.conv_point(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
def get_block(
block_type: str,
in_channels: int,
out_channels: int,
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_mutliscales: tuple = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
elif block_type == "EfficientViTBlock":
block = EfficientViTBlock(
in_channels,
attention_head_dim=attention_head_dim,
norm_type=norm_type,
qkv_multiscales=qkv_mutliscales
)
else:
raise ValueError(f"Block with {block_type=} is not supported.")
return block
class DCDownBlock2d(nn.Module):
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
super().__init__()
self.downsample = downsample
self.factor = 2
self.stride = 1 if downsample else 2
self.group_size = in_channels * self.factor**2 // out_channels
self.shortcut = shortcut
out_ratio = self.factor**2
if downsample:
assert out_channels % out_ratio == 0
out_channels = out_channels // out_ratio
self.conv = ops.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=self.stride,
padding=1,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = self.conv(hidden_states)
if self.downsample:
x = F.pixel_unshuffle(x, self.factor)
if self.shortcut:
y = F.pixel_unshuffle(hidden_states, self.factor)
y = y.unflatten(1, (-1, self.group_size))
y = y.mean(dim=2)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class DCUpBlock2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
interpolate: bool = False,
shortcut: bool = True,
interpolation_mode: str = "nearest",
) -> None:
super().__init__()
self.interpolate = interpolate
self.interpolation_mode = interpolation_mode
self.shortcut = shortcut
self.factor = 2
self.repeats = out_channels * self.factor**2 // in_channels
out_ratio = self.factor**2
if not interpolate:
out_channels = out_channels * out_ratio
self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.interpolate:
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
x = self.conv(x)
else:
x = self.conv(hidden_states)
x = F.pixel_shuffle(x, self.factor)
if self.shortcut:
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
y = F.pixel_shuffle(y, self.factor)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class Encoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: str or tuple = "ResBlock",
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True,
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
if layers_per_block[0] > 0:
self.conv_in = ops.Conv2d(
in_channels,
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
kernel_size=3,
stride=1,
padding=1,
)
else:
self.conv_in = DCDownBlock2d(
in_channels=in_channels,
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=False,
)
down_blocks = []
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
down_block_list = []
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type="rms_norm",
act_fn="silu",
qkv_mutliscales=qkv_multiscales[i],
)
down_block_list.append(block)
if i < num_blocks - 1 and num_layers > 0:
downsample_block = DCDownBlock2d(
in_channels=out_channel,
out_channels=block_out_channels[i + 1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=True,
)
down_block_list.append(downsample_block)
down_blocks.append(nn.Sequential(*down_block_list))
self.down_blocks = nn.ModuleList(down_blocks)
self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
self.out_shortcut = out_shortcut
if out_shortcut:
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
if self.out_shortcut:
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
x = x.mean(dim=2)
hidden_states = self.conv_out(hidden_states) + x
else:
hidden_states = self.conv_out(hidden_states)
return hidden_states
class Decoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: str or tuple = "ResBlock",
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
norm_type: str or tuple = "rms_norm",
act_fn: str or tuple = "silu",
upsample_block_type: str = "pixel_shuffle",
in_shortcut: bool = True,
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
if isinstance(norm_type, str):
norm_type = (norm_type,) * num_blocks
if isinstance(act_fn, str):
act_fn = (act_fn,) * num_blocks
self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
self.in_shortcut = in_shortcut
if in_shortcut:
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
up_blocks = []
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
up_block_list = []
if i < num_blocks - 1 and num_layers > 0:
upsample_block = DCUpBlock2d(
block_out_channels[i + 1],
out_channel,
interpolate=upsample_block_type == "interpolate",
shortcut=True,
)
up_block_list.append(upsample_block)
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type=norm_type[i],
act_fn=act_fn[i],
qkv_mutliscales=qkv_multiscales[i],
)
up_block_list.append(block)
up_blocks.insert(0, nn.Sequential(*up_block_list))
self.up_blocks = nn.ModuleList(up_blocks)
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
self.conv_act = nn.ReLU()
self.conv_out = None
if layers_per_block[0] > 0:
self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
else:
self.conv_out = DCUpBlock2d(
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.in_shortcut:
x = hidden_states.repeat_interleave(
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
)
hidden_states = self.conv_in(hidden_states) + x
else:
hidden_states = self.conv_in(hidden_states)
for up_block in reversed(self.up_blocks):
hidden_states = up_block(hidden_states)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderDC(nn.Module):
def __init__(
self,
in_channels: int = 2,
latent_channels: int = 8,
attention_head_dim: int = 32,
encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
upsample_block_type: str = "interpolate",
downsample_block_type: str = "Conv",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu",
scaling_factor: float = 0.41407,
) -> None:
super().__init__()
self.encoder = Encoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=encoder_block_types,
block_out_channels=encoder_block_out_channels,
layers_per_block=encoder_layers_per_block,
qkv_multiscales=encoder_qkv_multiscales,
downsample_block_type=downsample_block_type,
)
self.decoder = Decoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=decoder_block_types,
block_out_channels=decoder_block_out_channels,
layers_per_block=decoder_layers_per_block,
qkv_multiscales=decoder_qkv_multiscales,
norm_type=decoder_norm_types,
act_fn=decoder_act_fns,
upsample_block_type=upsample_block_type,
)
self.scaling_factor = scaling_factor
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Internal encoding function."""
encoded = self.encoder(x)
return encoded * self.scaling_factor
def decode(self, z: torch.Tensor) -> torch.Tensor:
# Scale the latents back
z = z / self.scaling_factor
decoded = self.decoder(z)
return decoded
def forward(self, x: torch.Tensor) -> torch.Tensor:
z = self.encode(x)
return self.decode(z)
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
import torch
from .autoencoder_dc import AutoencoderDC
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, ACE model will be broken")
import torchvision.transforms as transforms
from .music_vocoder import ADaMoSHiFiGANV1
class MusicDCAE(torch.nn.Module):
def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
super(MusicDCAE, self).__init__()
self.dcae = AutoencoderDC(**dcae_config)
self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)
if source_sample_rate is None:
self.source_sample_rate = 48000
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
self.min_mel_value = -11.0
self.max_mel_value = 3.0
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
self.mel_chunk_size = 1024
self.time_dimention_multiple = 8
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
image = self.vocoder.mel_transform(audios[i])
mels.append(image)
mels = torch.stack(mels)
return mels
@torch.no_grad()
def encode(self, audios, audio_lengths=None, sr=None):
if audio_lengths is None:
audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
audio_lengths = audio_lengths.to(audios.device)
if sr is None:
sr = self.source_sample_rate
if sr != 44100:
audios = torchaudio.functional.resample(audios, sr, 44100)
max_audio_len = audios.shape[-1]
if max_audio_len % (8 * 512) != 0:
audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))
mels = self.forward_mel(audios)
mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
mels = self.transform(mels)
latents = []
for mel in mels:
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
latents = latents / self.scale_factor + self.shift_factor
pred_wavs = []
for latent in latents:
mels = self.dcae.decoder(latent.unsqueeze(0))
mels = mels * 0.5 + 0.5
mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
return sr, pred_wavs, latents, latent_lengths
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment