Commit 1f5da520 authored by yangzhong

git init
from typing import Dict, Union
import torch
import torch.nn as nn
from ...util import append_dims, instantiate_from_config
class Denoiser(nn.Module):
def __init__(self, weighting_config, scaling_config):
super().__init__()
self.weighting = instantiate_from_config(weighting_config)
self.scaling = instantiate_from_config(scaling_config)
def possibly_quantize_sigma(self, sigma):
return sigma
def possibly_quantize_c_noise(self, c_noise):
return c_noise
def w(self, sigma):
return self.weighting(sigma)
def forward(
self,
network: nn.Module,
input: torch.Tensor,
sigma: torch.Tensor,
cond: Dict,
**additional_model_inputs,
) -> torch.Tensor:
sigma = self.possibly_quantize_sigma(sigma)
sigma_shape = sigma.shape
sigma = append_dims(sigma, input.ndim)
c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        # input carries an extra conditioning stack (e.g. the lq latents concatenated along
        # dim 2 in VanillaCFG.prepare_inputs); apply the skip connection only to the noisy half
        if input.shape[2] == 32:
input_noise_state, _ = input.chunk(2, dim=2)
res = input_noise_state * c_skip
else:
res = input * c_skip
# print('input shape:', input.shape) # torch.Size([2, 8, 32, 60, 90])
return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + res
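# Illustrative usage sketch (not part of the original file): Denoiser.forward applies the
# preconditioning  denoised = network(c_in * x, c_noise, cond) * c_out + c_skip * x.
# The config targets below assume the weighting/scaling classes defined further down live
# under the upstream sgm module layout; adjust the paths if this repository differs.
def _denoiser_usage_example():
    denoiser = Denoiser(
        weighting_config={"target": "sgm.modules.diffusionmodules.denoiser_weighting.UnitWeighting"},
        scaling_config={"target": "sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling"},
    )

    def dummy_network(x, c_noise, cond, **kwargs):
        return torch.zeros_like(x)  # stand-in for the diffusion backbone

    x = torch.randn(2, 4, 16, 60, 90)  # dim 2 != 32, so the plain c_skip branch is taken
    out = denoiser(dummy_network, x, torch.ones(2), cond={})
    assert out.shape == x.shape  # with EpsScaling and a zero network, out equals x (c_skip == 1)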
class DiscreteDenoiser(Denoiser):
def __init__(
self,
weighting_config,
scaling_config,
num_idx,
discretization_config,
do_append_zero=False,
quantize_c_noise=True,
flip=True,
):
super().__init__(weighting_config, scaling_config)
sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
self.sigmas = sigmas
# self.register_buffer("sigmas", sigmas)
self.quantize_c_noise = quantize_c_noise
def sigma_to_idx(self, sigma):
dists = sigma - self.sigmas.to(sigma.device)[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
def idx_to_sigma(self, idx):
return self.sigmas.to(idx.device)[idx]
def possibly_quantize_sigma(self, sigma):
return self.idx_to_sigma(self.sigma_to_idx(sigma))
def possibly_quantize_c_noise(self, c_noise):
if self.quantize_c_noise:
return self.sigma_to_idx(c_noise)
else:
return c_noise
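# Standalone illustration (not part of the original file) of the nearest-neighbour
# quantization performed by sigma_to_idx / idx_to_sigma above, with toy values.
def _sigma_quantization_example():
    sigmas = torch.tensor([14.6, 7.0, 3.0, 1.0, 0.1])  # discrete noise levels (example values)
    sigma = torch.tensor([2.7, 0.2])                    # continuous sigmas to be snapped
    dists = sigma - sigmas[:, None]                     # (num_idx, batch) pairwise differences
    idx = dists.abs().argmin(dim=0).view(sigma.shape)   # index of the closest discrete level
    return sigmas[idx]                                   # tensor([3.0000, 0.1000])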
# ---- denoiser scaling classes (separate source file in this commit) ----
from abc import ABC, abstractmethod
from typing import Any, Tuple
import torch
class DenoiserScaling(ABC):
@abstractmethod
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
pass
class EDMScaling:
def __init__(self, sigma_data: float = 0.5):
self.sigma_data = sigma_data
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
c_noise = 0.25 * sigma.log()
return c_skip, c_out, c_in, c_noise
class EpsScaling:
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = torch.ones_like(sigma, device=sigma.device)
c_out = -sigma
c_in = 1 / (sigma**2 + 1.0) ** 0.5
c_noise = sigma.clone()
return c_skip, c_out, c_in, c_noise
class VScaling:
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
c_noise = sigma.clone()
return c_skip, c_out, c_in, c_noise
class VScalingWithEDMcNoise(DenoiserScaling):
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
c_noise = 0.25 * sigma.log()
return c_skip, c_out, c_in, c_noise
class VideoScaling: # similar to VScaling
def __call__(
self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = alphas_cumprod_sqrt
c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5)
c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device)
c_noise = additional_model_inputs["idx"].clone()
return c_skip, c_out, c_in, c_noise
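# Illustrative comparison (not part of the original file): evaluate the four preconditioning
# coefficients of the EDM and v-prediction parameterizations defined above at a few noise levels.
def _scaling_coefficients_example():
    sigma = torch.tensor([0.1, 1.0, 10.0])
    for name, scaling in (("EDM", EDMScaling(sigma_data=0.5)), ("v", VScaling())):
        c_skip, c_out, c_in, c_noise = scaling(sigma)
        # c_skip shrinks and |c_out| grows with sigma, while c_in keeps the network input
        # at roughly unit variance in both parameterizations
        print(name, c_skip, c_out, c_in, c_noise)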
# ---- denoiser loss weightings (separate source file in this commit) ----
import torch
class UnitWeighting:
def __call__(self, sigma):
return torch.ones_like(sigma, device=sigma.device)
class EDMWeighting:
def __init__(self, sigma_data=0.5):
self.sigma_data = sigma_data
def __call__(self, sigma):
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
class VWeighting(EDMWeighting):
def __init__(self):
super().__init__(sigma_data=1.0)
class EpsWeighting:
def __call__(self, sigma):
return sigma**-2.0
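# Numeric sanity check (not part of the original file): the EDM weighting above equals
# 1 / c_out(sigma)^2 for EDMScaling with the same sigma_data, i.e. the weighted loss is
# uniform once measured in the c_out-scaled output space.
def _edm_weighting_identity_example():
    sigma = torch.tensor([0.2, 1.0, 5.0])
    sigma_data = 0.5
    w = EDMWeighting(sigma_data=sigma_data)(sigma)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5  # c_out from EDMScaling above
    assert torch.allclose(w, 1.0 / c_out**2)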
# ---- sigma / timestep discretizations (separate source file in this commit) ----
from abc import abstractmethod
from functools import partial
import numpy as np
import torch
from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero
def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray:
return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
class Discretization:
def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False):
if return_idx:
sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx)
else:
            sigmas = self.get_sigmas(n, device=device)  # EDM/Legacy get_sigmas do not accept return_idx
sigmas = append_zero(sigmas) if do_append_zero else sigmas
if return_idx:
return sigmas if not flip else torch.flip(sigmas, (0,)), idx
else:
return sigmas if not flip else torch.flip(sigmas, (0,))
@abstractmethod
def get_sigmas(self, n, device):
pass
class EDMDiscretization(Discretization):
def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho
def get_sigmas(self, n, device="cpu"):
ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = self.sigma_min ** (1 / self.rho)
max_inv_rho = self.sigma_max ** (1 / self.rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
return sigmas
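# Illustrative usage (not part of the original file): the rho-schedule above produces n sigmas
# running from sigma_max down to sigma_min; with do_append_zero=True a final 0.0 is appended
# for the last denoising step.
def _edm_discretization_example():
    disc = EDMDiscretization(sigma_min=0.002, sigma_max=80.0, rho=7.0)
    return disc(5, do_append_zero=True)  # 6 decreasing values: 80.0, ..., 0.002, 0.0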
class LegacyDDPMDiscretization(Discretization):
def __init__(
self,
linear_start=0.00085,
linear_end=0.0120,
num_timesteps=1000,
):
super().__init__()
self.num_timesteps = num_timesteps
betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.to_torch = partial(torch.tensor, dtype=torch.float32)
def get_sigmas(self, n, device="cpu"):
if n < self.num_timesteps:
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
alphas_cumprod = self.alphas_cumprod[timesteps]
elif n == self.num_timesteps:
alphas_cumprod = self.alphas_cumprod
else:
raise ValueError
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029
class ZeroSNRDDPMDiscretization(Discretization):
def __init__(
self,
linear_start=0.00085,
linear_end=0.0120,
num_timesteps=1000,
shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale)
keep_start=False,
post_shift=False,
):
super().__init__()
if keep_start and not post_shift:
linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
self.num_timesteps = num_timesteps
betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.to_torch = partial(torch.tensor, dtype=torch.float32)
# SNR shift
if not post_shift:
self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod)
self.post_shift = post_shift
self.shift_scale = shift_scale
def get_sigmas(self, n, device="cpu", return_idx=False):
if n < self.num_timesteps:
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
alphas_cumprod = self.alphas_cumprod[timesteps]
elif n == self.num_timesteps:
alphas_cumprod = self.alphas_cumprod
else:
raise ValueError
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
alphas_cumprod = to_torch(alphas_cumprod)
alphas_cumprod_sqrt = alphas_cumprod.sqrt()
alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
if self.post_shift:
alphas_cumprod_sqrt = (
alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
) ** 0.5
if return_idx:
return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps
else:
return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99
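# Illustrative usage (not part of the original file): the rescaling above enforces zero terminal
# SNR, so the first returned value is ~0 (pure noise at the final timestep) while the last stays
# at its original value (~0.9996 for the default linear schedule).
def _zero_snr_example():
    disc = ZeroSNRDDPMDiscretization()
    alphas_cumprod_sqrt = disc.get_sigmas(1000)
    return alphas_cumprod_sqrt[0], alphas_cumprod_sqrt[-1]  # ~0.0 and ~0.9996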
# ---- guiders for classifier-free guidance (separate source file in this commit) ----
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from functools import partial
import math
import torch
from einops import rearrange, repeat
from ...util import append_dims, default, instantiate_from_config
class Guider(ABC):
@abstractmethod
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
pass
def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]:
pass
class VanillaCFG:
"""
implements parallelized CFG
"""
def __init__(self, scale, dyn_thresh_config=None):
self.scale = scale
scale_schedule = lambda scale, sigma: scale # independent of step
self.scale_schedule = partial(scale_schedule, scale)
self.dyn_thresh = instantiate_from_config(
default(
dyn_thresh_config,
{"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},
)
)
def __call__(self, x, sigma, scale=None):
x_u, x_c = x.chunk(2)
scale_value = default(scale, self.scale_schedule(sigma))
x_pred = self.dyn_thresh(x_u, x_c, scale_value)
return x_pred
def prepare_inputs(self, x, s, c, uc, lq=None):
c_out = dict()
for k in c:
if k in ["vector", "crossattn", "concat"]:
c_out[k] = torch.cat((uc[k], c[k]), 0)
else:
assert c[k] == uc[k]
c_out[k] = c[k]
x = torch.cat([x] * 2)
if lq is not None:
# print("lq shape:", lq.shape)
# print("x shape:", x.shape)
x = torch.cat((x, lq), dim=2)
return x, torch.cat([s] * 2), c_out
class DynamicCFG(VanillaCFG):
def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
super().__init__(scale, dyn_thresh_config)
scale_schedule = (
lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2
)
self.scale_schedule = partial(scale_schedule, scale)
self.dyn_thresh = instantiate_from_config(
default(
dyn_thresh_config,
{"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},
)
)
def __call__(self, x, sigma, step_index, scale=None):
x_u, x_c = x.chunk(2)
scale_value = self.scale_schedule(sigma, step_index.item())
x_pred = self.dyn_thresh(x_u, x_c, scale_value)
return x_pred
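# Illustrative note (not part of the original file): the schedule above ramps the effective
# guidance scale from 1 at the first step towards 1 + scale at the last step, following a
# cosine curve whose shape is controlled by `exp`.
def _dynamic_cfg_schedule_example(scale=6.0, exp=5.0, num_steps=50):
    return [
        1 + scale * (1 - math.cos(math.pi * (i / num_steps) ** exp)) / 2
        for i in range(num_steps)
    ]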
class IdentityGuider:
def __call__(self, x, sigma):
return x
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
c_out[k] = c[k]
return x, s, c_out
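# Standalone sketch (not part of the original file) of the classifier-free guidance combination
# that VanillaCFG delegates to its dyn_thresh module; with the default NoDynamicThresholding
# this is expected to be the usual update x_u + scale * (x_c - x_u).
def _cfg_combine_example():
    x = torch.randn(4, 16, 60, 90)  # stacked [uncond, cond] batch, as built by prepare_inputs
    x_u, x_c = x.chunk(2)
    scale = 7.5
    return x_u + scale * (x_c - x_u)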