Commit e2364931: pixart-alpha
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataHed', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
model = 'PixArtMS_XL_2'
fp32_attention = False # set to True if you get NaN loss
load_from = 'path-to-pixart-checkpoints'
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
lewei_scale = 2.0
# training setting
num_workers=10
train_batch_size = 4 # set the batch size according to your VRAM
num_epochs = 10 # 3
gradient_accumulation_steps = 4
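# note: with train_batch_size = 4 and gradient_accumulation_steps = 4, the
# effective batch size is 4 * 4 = 16 per GPU (multiplied again by the number
# of GPUs under a distributed launcher)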
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=0)
save_model_epochs=5
save_model_steps=1000
log_interval = 20
eval_sampling_steps = 200
work_dir = 'output_debug/debug'
# controlnet related params
copy_blocks_num = 13
class_dropout_prob = 0.5
train_ratio = 1
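# Illustrative sketch (not part of this commit): configs in this style are
# plain Python files with mmcv-style `_base_` inheritance, so they can be
# loaded with mmcv's Config; the path below is hypothetical.
from mmcv import Config

cfg = Config.fromfile('configs/pixart_app_config/PixArt_xl2_img1024_controlHed.py')  # hypothetical path
print(cfg.model, cfg.train_batch_size, cfg.optimizer['lr'])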
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data/dreambooth/dataset'
data = dict(type='DreamBooth', root='dog6', prompt=['a photo of sks dog'], transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
model = 'PixArtMS_XL_2' # model for multi-scale training
fp32_attention = True
load_from = 'Path/to/PixArt-XL-2-1024-MS.pth'
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect-ratio bucket (ASPECT_RATIO_256 / ASPECT_RATIO_512 / ASPECT_RATIO_1024)
multi_scale = True # whether to use multi-scale dataset training
lewei_scale = 2.0
# training setting
num_workers=1
train_batch_size = 1
num_epochs = 200
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=5e-6, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=0)
auto_lr = None
log_interval = 1
save_model_epochs=10000
save_model_steps=100
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataHed', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 512
# model setting
model = 'PixArt_XL_2'
fp32_attention = False # set to True if you get NaN loss
load_from = 'path-to-pixart-checkpoints'
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
lewei_scale = 1.0
# training setting
num_workers=10
train_batch_size = 12 # max 96 for DiT-L/4 with grad checkpointing
num_epochs = 1000 # 3
gradient_accumulation_steps = 4
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=0)
save_model_epochs=5
save_model_steps=1000
log_interval = 20
eval_sampling_steps = 200
work_dir = 'output_debug/debug'
# controlnet related params
copy_blocks_num = 13
class_dropout_prob = 0.5
train_ratio = 0.1
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
window_block_indexes = []
window_size=0
use_rel_pos=False
model = 'PixArt_XL_2'
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
lewei_scale = 2.0
# training setting
num_workers=10
train_batch_size = 2 # 32
num_epochs = 200 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
eval_sampling_steps = 200
log_interval = 20
save_model_epochs=1
save_model_steps=2000
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
model = 'PixArtMS_XL_2' # model for multi-scale training
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect-ratio bucket (ASPECT_RATIO_256 / ASPECT_RATIO_512 / ASPECT_RATIO_1024)
multi_scale = True # whether to use multi-scale dataset training
lewei_scale = 2.0
# training setting
num_workers=10
train_batch_size = 12 # max 14 for PixArt-XL/2 with grad checkpointing
num_epochs = 10 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
save_model_epochs=1
save_model_steps=2000
log_interval = 20
eval_sampling_steps = 200
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 1024
# model setting
model = 'PixArtMS_XL_2' # model for multi-scale training
fp32_attention = False # set to True if you get NaN loss
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
aspect_ratio_type = 'ASPECT_RATIO_1024' # base aspect-ratio bucket (ASPECT_RATIO_256 / ASPECT_RATIO_512 / ASPECT_RATIO_1024)
multi_scale = True # whether to use multi-scale dataset training
lewei_scale = 2.0
# training setting
num_workers=4
train_batch_size = 16 # max 12 for PixArt-XL/2 with grad checkpointing; 16 for LCM-LoRA
num_epochs = 10 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.0, eps=1e-10)
# optimizer = dict(type='CAMEWrapper', lr=1e-7, weight_decay=0.0, betas=(0.9, 0.999, 0.9999), eps=(1e-30, 1e-16))
lr_schedule_args = dict(num_warmup_steps=100)
save_model_epochs=1
save_model_steps=200
valid_num=0 # an aspect-ratio bucket counts as valid only when its sample count >= valid_num
log_interval = 10
eval_sampling_steps = 200
work_dir = 'output/debug'
# LCM
loss_type = 'huber'
huber_c = 0.001
num_ddim_timesteps=50
w_max = 15.0
w_min = 3.0
ema_decay = 0.95
cfg_scale = 4.5
class_dropout_prob = 0.
lora_rank = 32
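# Illustrative sketch (not part of this commit): with loss_type = 'huber',
# LCM-style training typically uses the pseudo-Huber loss
# sqrt((x - y)^2 + c^2) - c with c = huber_c; whether the trainer consumes
# these fields exactly this way is an assumption.
import torch

def pseudo_huber_loss(pred: torch.Tensor, target: torch.Tensor, c: float = 0.001) -> torch.Tensor:
    # Behaves like an L2 loss for small residuals and like L1 for large ones,
    # which makes distillation more robust to outlier targets.
    return (torch.sqrt((pred - target) ** 2 + c ** 2) - c).mean()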
_base_ = ['../PixArt_xl2_sam.py']
data_root = 'data'
# image_list_txt = ['part0.txt', 'part1.txt', 'part2.txt', 'part3.txt', 'part4.txt', 'part5.txt', 'part6.txt', 'part7.txt', 'part8.txt',
# 'part9.txt', 'part10.txt', 'part11.txt', 'part12.txt', 'part13.txt', 'part14.txt','part15.txt','part16.txt',
# 'part17.txt','part18.txt','part19.txt','part20.txt','part21.txt', 'part22.txt', 'part23.txt', 'part24.txt',
# 'part25.txt', 'part26.txt', 'part27.txt', 'part28.txt', 'part29.txt', 'part30.txt', 'part31.txt']
# data = dict(type='SAM', root='SA1B', image_list_txt=image_list_txt, transform='default_train', load_vae_feat=True)
image_list_txt = ['part0.txt']
data = dict(type='SAM', root='data_toy', image_list_txt=image_list_txt, transform='default_train', load_vae_feat=True)
image_size = 256
# model setting
window_block_indexes=[]
window_size=0
use_rel_pos=False
model = 'PixArt_XL_2'
fp32_attention = False
load_from = None
# vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
vae_pretrained = "pretrained_models/hub/pixart_alpha/sd-vae-ft-ema"
# training setting
use_fsdp=False # whether to train with FSDP
num_workers=1
train_batch_size = 2 # 32
num_epochs = 2 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
eval_sampling_steps = 1
log_interval = 1
save_model_epochs=1
save_model_steps=1
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 256
# model setting
window_block_indexes=[]
window_size=0
use_rel_pos=False
model = 'PixArt_XL_2'
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
# training setting
eval_sampling_steps = 200
num_workers=10
train_batch_size = 176 # max 96 for PixArt-L/4 with grad checkpointing
num_epochs = 200 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
log_interval = 20
save_model_epochs=5
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalData', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 512
# model setting
window_block_indexes = []
window_size=0
use_rel_pos=False
model = 'PixArt_XL_2'
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
lewei_scale = 1.0
# training setting
use_fsdp=False # whether to train with FSDP
num_workers=10
train_batch_size = 38 # 32
num_epochs = 200 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
eval_sampling_steps = 200
log_interval = 20
save_model_epochs=1
work_dir = 'output/debug'
_base_ = ['../PixArt_xl2_internal.py']
data_root = 'data'
image_list_json = ['data_info.json',]
data = dict(type='InternalDataMS', root='InternData', image_list_json=image_list_json, transform='default_train', load_vae_feat=True)
image_size = 512
# model setting
model = 'PixArtMS_XL_2' # model for multi-scale training
fp32_attention = True
load_from = None
vae_pretrained = "output/pretrained_models/sd-vae-ft-ema"
window_block_indexes = []
window_size=0
use_rel_pos=False
aspect_ratio_type = 'ASPECT_RATIO_512' # base aspect-ratio bucket (ASPECT_RATIO_256 / ASPECT_RATIO_512 / ASPECT_RATIO_1024)
multi_scale = True # whether to use multi-scale dataset training
lewei_scale = 1.0
# training setting
num_workers=10
train_batch_size = 40 # max 40 for PixArt-XL/2 with grad checkpointing
num_epochs = 20 # 3
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=3e-2, eps=1e-10)
lr_schedule_args = dict(num_warmup_steps=1000)
save_model_epochs=1
save_model_steps=2000
log_interval = 20
eval_sampling_steps = 200
work_dir = 'output/debug'
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from .iddpm import IDDPM
from .dpm_solver import DPMS
from .sa_sampler import SASolverSampler
import torch
from .model import gaussian_diffusion as gd
from .model.dpm_solver import model_wrapper, DPM_Solver, NoiseScheduleVP
def DPMS(model, condition, uncondition, cfg_scale, model_type='noise', noise_schedule="linear",
         guidance_type='classifier-free', model_kwargs=None, diffusion_steps=1000):
    if model_kwargs is None:
        model_kwargs = {}
    betas = torch.tensor(gd.get_named_beta_schedule(noise_schedule, diffusion_steps))

    ## 1. Define the noise schedule.
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)

    ## 2. Convert the discrete-time `model` to a continuous-time noise
    ## prediction model. Here is an example for a diffusion model `model`
    ## with the noise prediction type ("noise").
    model_fn = model_wrapper(
        model,
        noise_schedule,
        model_type=model_type,
        model_kwargs=model_kwargs,
        guidance_type=guidance_type,
        condition=condition,
        unconditional_condition=uncondition,
        guidance_scale=cfg_scale,
    )

    ## 3. Define DPM-Solver and sample with the multistep DPM-Solver.
    return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
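

# Illustrative usage sketch (not part of the original file). A toy
# epsilon-prediction network stands in for the PixArt transformer; the
# DPM_Solver.sample keyword arguments below follow the upstream dpm-solver
# reference implementation.
if __name__ == "__main__":
    def toy_model(x, t, **kwargs):
        # Stand-in denoiser: predicts zero noise everywhere.
        return torch.zeros_like(x)

    solver = DPMS(toy_model, condition=None, uncondition=None, cfg_scale=1.0,
                  guidance_type='uncond')
    x_T = torch.randn(1, 4, 32, 32)
    sample = solver.sample(x_T, steps=20, order=2,
                           skip_type="time_uniform", method="multistep")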
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
from diffusion.model.respace import SpacedDiffusion, space_timesteps
from .model import gaussian_diffusion as gd
def IDDPM(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    pred_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000,
    snr=False,
    return_startx=False,
):
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
        ),
        model_var_type=(
            (
                gd.ModelVarType.LEARNED_RANGE
                if learn_sigma
                else (
                    gd.ModelVarType.FIXED_LARGE
                    if not sigma_small
                    else gd.ModelVarType.FIXED_SMALL
                )
            )
            if pred_sigma
            else None
        ),
        loss_type=loss_type,
        snr=snr,
        return_startx=return_startx,
        # rescale_timesteps=rescale_timesteps,
    )
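

# Illustrative usage sketch (assumption: SpacedDiffusion keeps the
# improved-diffusion API, where training_losses(model, x_start, t,
# model_kwargs) returns a dict containing a "loss" tensor):
#
#   diffusion = IDDPM(timestep_respacing="")  # full 1000-step training schedule
#   t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=x.device)
#   loss = diffusion.training_losses(model, x, t, model_kwargs=dict(y=y))["loss"].mean()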
from mmcv import Registry
from diffusion.model.utils import set_grad_checkpoint
MODELS = Registry('models')
def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs):
    if isinstance(cfg, str):
        cfg = dict(type=cfg)
    model = MODELS.build(cfg, default_args=kwargs)
    if use_grad_checkpoint:
        set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step)
    return model
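

# Illustrative usage sketch (assumption: 'PixArt_XL_2' is registered in
# MODELS by the model definitions, and the kwargs below are accepted by its
# constructor as default_args):
#
#   model = build_model('PixArt_XL_2',
#                       use_grad_checkpoint=True,
#                       use_fp32_attention=True,
#                       input_size=1024 // 8,
#                       lewei_scale=2.0)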
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
import numpy as np
import torch as th
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two Gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = next(
        (
            obj
            for obj in (mean1, logvar1, mean2, logvar2)
            if isinstance(obj, th.Tensor)
        ),
        None,
    )
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )
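

# Quick sanity check (illustrative, not part of the original file):
#   normal_kl(th.zeros(4), th.zeros(4), 0.0, 0.0) -> tensor of zeros,
# since the KL divergence between identical Gaussians is zero and scalar
# arguments broadcast against tensor ones.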
def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.

    :param x: the targets.
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
        normalized_x
    )
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
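

# Illustrative shape check (not part of the original file): with a very
# small stddev centered on the data, each pixel's discretized bin captures
# almost all probability mass, so log-probs are near zero; the output
# always matches the input shape.
if __name__ == "__main__":
    x = th.zeros(2, 3, 8, 8)
    ll = discretized_gaussian_log_likelihood(x, means=th.zeros_like(x),
                                             log_scales=th.full_like(x, -7.0))
    assert ll.shape == x.shape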
import random
import numpy as np
from tqdm import tqdm
from diffusion.model.utils import *
# ----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2).
def edm_sampler(
    net, latents, class_labels=None, cfg_scale=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, **kwargs
):
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
            sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:])))):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            denoised = net(x_next.float(), t_next, class_labels, cfg_scale, **kwargs)['x'].to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next
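

# Illustrative usage sketch (assumptions: `net` exposes sigma_min, sigma_max
# and round_sigma, and returns a dict with key 'x', as the loop above expects):
#
#   latents = torch.randn(1, 4, 32, 32, device=device)
#   images = edm_sampler(net, latents, class_labels=None, cfg_scale=4.5,
#                        num_steps=18, sigma_min=0.002, sigma_max=80, rho=7)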
# ----------------------------------------------------------------------------
# Generalized ablation sampler, representing the superset of all sampling
# methods discussed in the paper.
def ablation_sampler(
    net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=None, sigma_max=None, rho=7,
    solver='heun', discretization='edm', schedule='linear', scaling='none',
    epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    assert solver in ['euler', 'heun']
    assert discretization in ['vp', 've', 'iddpm', 'edm']
    assert schedule in ['vp', 've', 'linear']
    assert scaling in ['vp', 'none']

    # Helper functions for VP & VE noise level schedules.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
            sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Compute corresponding betas for VP.
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    if discretization == 'vp':
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == 'iddpm':
        u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device):  # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
                sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Define noise level schedule.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'linear'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == 'none'
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
            t_hat) * S_noise * randn_like(x_cur)

        # Euler step.
        h = t_next - t_hat
        denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
            t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction.
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
                t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

    return x_next