Commit 1f5da520 authored by yangzhong's avatar yangzhong
Browse files

git init

parents
Pipeline #3144 failed with stages
in 0 seconds
import random
import torch
from .schedules_sdedit import karras_schedule
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
from video_to_video.utils.logger import get_logger
logger = get_logger()
__all__ = ['GaussianDiffusion']
def _i(tensor, t, x):
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
return tensor[t.to(tensor.device)].view(shape).to(x.device)
class GaussianDiffusion(object):
def __init__(self, sigmas):
self.sigmas = sigmas
self.alphas = torch.sqrt(1 - sigmas**2)
self.num_timesteps = len(sigmas)
def diffuse(self, x0, t, noise=None):
noise = torch.randn_like(x0) if noise is None else noise
xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
return xt
def get_velocity(self, x0, xt, t):
sigmas = _i(self.sigmas, t, xt)
alphas = _i(self.alphas, t, xt)
velocity = (alphas * xt - x0) / sigmas
return velocity
def get_x0(self, v, xt, t):
sigmas = _i(self.sigmas, t, xt)
alphas = _i(self.alphas, t, xt)
x0 = alphas * xt - sigmas * v
return x0
def denoise(self,
xt,
t,
s,
model,
model_kwargs={},
guide_scale=None,
guide_rescale=None,
clamp=None,
percentile=None,
variant_info=None,):
s = t - 1 if s is None else s
# hyperparams
sigmas = _i(self.sigmas, t, xt)
alphas = _i(self.alphas, t, xt)
alphas_s = _i(self.alphas, s.clamp(0), xt)
alphas_s[s < 0] = 1.
sigmas_s = torch.sqrt(1 - alphas_s**2)
# precompute variables
betas = 1 - (alphas / alphas_s)**2
coef1 = betas * alphas_s / sigmas**2
coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
var = betas * (sigmas_s / sigmas)**2
log_var = torch.log(var).clamp_(-20, 20)
# prediction
if guide_scale is None:
assert isinstance(model_kwargs, dict)
out = model(xt, t=t, **model_kwargs)
else:
# classifier-free guidance
assert isinstance(model_kwargs, list)
if len(model_kwargs) > 3:
y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
else:
y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], variant_info=variant_info)
if guide_scale == 1.:
out = y_out
else:
if len(model_kwargs) > 3:
u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
else:
u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], variant_info=variant_info)
out = u_out + guide_scale * (y_out - u_out)
if guide_rescale is not None:
assert guide_rescale >= 0 and guide_rescale <= 1
ratio = (
y_out.flatten(1).std(dim=1) / # noqa
(out.flatten(1).std(dim=1) + 1e-12)
).view((-1, ) + (1, ) * (y_out.ndim - 1))
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
x0 = alphas * xt - sigmas * out
# restrict the range of x0
if percentile is not None:
assert percentile > 0 and percentile <= 1
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
x0 = torch.min(s, torch.max(-s, x0)) / s
elif clamp is not None:
x0 = x0.clamp(-clamp, clamp)
# recompute eps using the restricted x0
eps = (xt - alphas * x0) / sigmas
# compute mu (mean of posterior distribution) using the restricted x0
mu = coef1 * x0 + coef2 * xt
return mu, var, log_var, x0, eps
@torch.no_grad()
def sample(self,
noise,
model,
model_kwargs={},
condition_fn=None,
guide_scale=None,
guide_rescale=None,
clamp=None,
percentile=None,
solver='euler_a',
solver_mode='fast',
steps=20,
t_max=None,
t_min=None,
discretization=None,
discard_penultimate_step=None,
return_intermediate=None,
show_progress=False,
seed=-1,
chunk_inds=None,
**kwargs):
# sanity check
assert isinstance(steps, (int, torch.LongTensor))
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
assert discretization in (None, 'leading', 'linspace', 'trailing')
assert discard_penultimate_step in (None, True, False)
assert return_intermediate in (None, 'x0', 'xt')
# function of diffusion solver
solver_fn = {
'heun': sample_heun,
'dpmpp_2m_sde': sample_dpmpp_2m_sde
}[solver]
# options
schedule = 'karras' if 'karras' in solver else None
discretization = discretization or 'linspace'
seed = seed if seed >= 0 else random.randint(0, 2**31)
if isinstance(steps, torch.LongTensor):
discard_penultimate_step = False
if discard_penultimate_step is None:
discard_penultimate_step = True if solver in (
'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
# function for denoising xt to get x0
intermediates = []
def model_fn(xt, sigma):
# denoising
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
guide_rescale, clamp, percentile)[-2]
# collect intermediate outputs
if return_intermediate == 'xt':
intermediates.append(xt)
elif return_intermediate == 'x0':
intermediates.append(x0)
return x0
mask_cond = model_kwargs[3]['mask_cond']
def model_chunk_fn(xt, sigma):
# denoising
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
cut_f_ind = O_LEN//2
results_list = []
for i in range(len(chunk_inds)):
ind_start, ind_end = chunk_inds[i]
xt_chunk = xt[:,:,ind_start:ind_end].clone()
cur_f = xt_chunk.size(2)
model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
guide_rescale, clamp, percentile)[-2]
if i == 0:
results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
elif i == len(chunk_inds)-1:
results_list.append(x0_chunk[:,:,cut_f_ind:])
else:
results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
x0 = torch.concat(results_list, dim=2)
torch.cuda.empty_cache()
return x0
# get timesteps
if isinstance(steps, int):
steps += 1 if discard_penultimate_step else 0
t_max = self.num_timesteps - 1 if t_max is None else t_max
t_min = 0 if t_min is None else t_min
# discretize timesteps
if discretization == 'leading':
steps = torch.arange(t_min, t_max + 1,
(t_max - t_min + 1) / steps).flip(0)
elif discretization == 'linspace':
steps = torch.linspace(t_max, t_min, steps)
elif discretization == 'trailing':
steps = torch.arange(t_max, t_min - 1,
-((t_max - t_min + 1) / steps))
if solver_mode == 'fast':
t_mid = 500
steps1 = torch.arange(t_max, t_mid - 1,
-((t_max - t_mid + 1) / 4))
steps2 = torch.arange(t_mid, t_min - 1,
-((t_mid - t_min + 1) / 11))
steps = torch.concat([steps1, steps2])
else:
raise NotImplementedError(
f'{discretization} discretization not implemented')
steps = steps.clamp_(t_min, t_max)
steps = torch.as_tensor(
steps, dtype=torch.float32, device=noise.device)
# get sigmas
sigmas = self._t_to_sigma(steps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
if schedule == 'karras':
if sigmas[0] == float('inf'):
sigmas = karras_schedule(
n=len(steps) - 1,
sigma_min=sigmas[sigmas > 0].min().item(),
sigma_max=sigmas[sigmas < float('inf')].max().item(),
rho=7.).to(sigmas)
sigmas = torch.cat([
sigmas.new_tensor([float('inf')]), sigmas,
sigmas.new_zeros([1])
])
else:
sigmas = karras_schedule(
n=len(steps),
sigma_min=sigmas[sigmas > 0].min().item(),
sigma_max=sigmas.max().item(),
rho=7.).to(sigmas)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
if discard_penultimate_step:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
fn = model_chunk_fn if chunk_inds is not None else model_fn
x0 = solver_fn(
noise, fn, sigmas, show_progress=show_progress, **kwargs)
return (x0, intermediates) if return_intermediate is not None else x0
@torch.no_grad()
def sample_sr(self,
noise,
model,
model_kwargs={},
condition_fn=None,
guide_scale=None,
guide_rescale=None,
clamp=None,
percentile=None,
solver='euler_a',
solver_mode='fast',
steps=20,
t_max=None,
t_min=None,
discretization=None,
discard_penultimate_step=None,
return_intermediate=None,
show_progress=False,
seed=-1,
chunk_inds=None,
variant_info=None,
**kwargs):
# sanity check
assert isinstance(steps, (int, torch.LongTensor))
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
assert discretization in (None, 'leading', 'linspace', 'trailing')
assert discard_penultimate_step in (None, True, False)
assert return_intermediate in (None, 'x0', 'xt')
# function of diffusion solver
solver_fn = {
'heun': sample_heun,
'dpmpp_2m_sde': sample_dpmpp_2m_sde
}[solver]
# options
schedule = 'karras' if 'karras' in solver else None
discretization = discretization or 'linspace'
seed = seed if seed >= 0 else random.randint(0, 2**31)
if isinstance(steps, torch.LongTensor):
discard_penultimate_step = False
if discard_penultimate_step is None:
discard_penultimate_step = True if solver in (
'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
# function for denoising xt to get x0
intermediates = []
def model_fn(xt, sigma, variant_info=None):
# denoising
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
# collect intermediate outputs
if return_intermediate == 'xt':
intermediates.append(xt)
elif return_intermediate == 'x0':
print('add intermediate outputs x0')
intermediates.append(x0)
return x0
# mask_cond = model_kwargs[3]['mask_cond']
def model_chunk_fn(xt, sigma, variant_info=None):
# denoising
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
cut_f_ind = O_LEN//2
results_list = []
for i in range(len(chunk_inds)):
ind_start, ind_end = chunk_inds[i]
xt_chunk = xt[:,:,ind_start:ind_end].clone()
model_kwargs[2]['hint_chunk'] = model_kwargs[2]['hint'][:,:,ind_start:ind_end].clone() # new added
cur_f = xt_chunk.size(2)
# model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
if i == 0:
results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
elif i == len(chunk_inds)-1:
results_list.append(x0_chunk[:,:,cut_f_ind:])
else:
results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
x0 = torch.concat(results_list, dim=2)
torch.cuda.empty_cache()
return x0
# get timesteps
if isinstance(steps, int):
steps += 1 if discard_penultimate_step else 0
t_max = self.num_timesteps - 1 if t_max is None else t_max
t_min = 0 if t_min is None else t_min
# discretize timesteps
if discretization == 'leading':
steps = torch.arange(t_min, t_max + 1,
(t_max - t_min + 1) / steps).flip(0)
elif discretization == 'linspace':
steps = torch.linspace(t_max, t_min, steps)
elif discretization == 'trailing':
steps = torch.arange(t_max, t_min - 1,
-((t_max - t_min + 1) / steps))
if solver_mode == 'fast':
t_mid = 500
steps1 = torch.arange(t_max, t_mid - 1,
-((t_max - t_mid + 1) / 4))
steps2 = torch.arange(t_mid, t_min - 1,
-((t_mid - t_min + 1) / 11))
steps = torch.concat([steps1, steps2])
else:
raise NotImplementedError(
f'{discretization} discretization not implemented')
steps = steps.clamp_(t_min, t_max)
steps = torch.as_tensor(
steps, dtype=torch.float32, device=noise.device)
# get sigmas
sigmas = self._t_to_sigma(steps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
if schedule == 'karras':
if sigmas[0] == float('inf'):
sigmas = karras_schedule(
n=len(steps) - 1,
sigma_min=sigmas[sigmas > 0].min().item(),
sigma_max=sigmas[sigmas < float('inf')].max().item(),
rho=7.).to(sigmas)
sigmas = torch.cat([
sigmas.new_tensor([float('inf')]), sigmas,
sigmas.new_zeros([1])
])
else:
sigmas = karras_schedule(
n=len(steps),
sigma_min=sigmas[sigmas > 0].min().item(),
sigma_max=sigmas.max().item(),
rho=7.).to(sigmas)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
if discard_penultimate_step:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
fn = model_chunk_fn if chunk_inds is not None else model_fn
x0 = solver_fn(
noise, fn, sigmas, variant_info=variant_info, show_progress=show_progress, **kwargs)
return (x0, intermediates) if return_intermediate is not None else x0
def _sigma_to_t(self, sigma):
if sigma == float('inf'):
t = torch.full_like(sigma, len(self.sigmas) - 1)
else:
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
(1 - self.sigmas**2)).log().to(sigma)
log_sigma = sigma.log()
dists = log_sigma - log_sigmas[:, None]
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = log_sigmas[low_idx], log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
t = t.view(sigma.shape)
if t.ndim == 0:
t = t.unsqueeze(0)
return t
def _t_to_sigma(self, t):
t = t.float()
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
(1 - self.sigmas**2)).log().to(t)
log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
log_sigma[torch.isnan(log_sigma)
| torch.isinf(log_sigma)] = float('inf')
return log_sigma.exp()
\ No newline at end of file
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
def betas_to_sigmas(betas):
return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
def sigmas_to_betas(sigmas):
square_alphas = 1 - sigmas**2
betas = 1 - torch.cat(
[square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
return betas
def logsnrs_to_sigmas(logsnrs):
return torch.sqrt(torch.sigmoid(-logsnrs))
def sigmas_to_logsnrs(sigmas):
square_sigmas = sigmas**2
return torch.log(square_sigmas / (1 - square_sigmas))
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
t_min = math.atan(math.exp(-0.5 * logsnr_min))
t_max = math.atan(math.exp(-0.5 * logsnr_max))
t = torch.linspace(1, 0, n)
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
return logsnrs
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
logsnrs += 2 * math.log(1 / scale)
return logsnrs
def _logsnr_cosine_interp(n,
logsnr_min=-15,
logsnr_max=15,
scale_min=2,
scale_max=4):
t = torch.linspace(1, 0, n)
logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
return logsnrs
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
ramp = torch.linspace(1, 0, n)
min_inv_rho = sigma_min**(1 / rho)
max_inv_rho = sigma_max**(1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
return sigmas
def logsnr_cosine_interp_schedule(n,
logsnr_min=-15,
logsnr_max=15,
scale_min=2,
scale_max=4):
return logsnrs_to_sigmas(
_logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min, scale_max))
def noise_schedule(schedule='logsnr_cosine_interp',
n=1000,
zero_terminal_snr=False,
**kwargs):
# compute sigmas
sigmas = {
'logsnr_cosine_interp': logsnr_cosine_interp_schedule
}[schedule](n, **kwargs)
# post-processing
if zero_terminal_snr and sigmas.max() != 1.0:
scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min())
sigmas = sigmas.min() + scale * (sigmas - sigmas.min())
return sigmas
\ No newline at end of file
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torchsde
from tqdm.auto import trange
from video_to_video.utils.logger import get_logger
logger = get_logger()
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
"""
Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step.
"""
if not eta:
return sigma_to, 0.
sigma_up = min(
sigma_to,
eta * (
sigma_to**2 * # noqa
(sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5)
sigma_down = (sigma_to**2 - sigma_up**2)**0.5
return sigma_down, sigma_up
def get_scalings(sigma):
c_out = -sigma
c_in = 1 / (sigma**2 + 1.**2)**0.5
return c_out, c_in
@torch.no_grad()
def sample_heun(noise,
model,
sigmas,
s_churn=0.,
s_tmin=0.,
s_tmax=float('inf'),
s_noise=1.,
show_progress=True):
"""
Implements Algorithm 2 (Heun steps) from Karras et al. (2022).
"""
x = noise * sigmas[0]
for i in trange(len(sigmas) - 1, disable=not show_progress):
gamma = 0.
if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
if sigmas[i] == float('inf'):
# Euler method
denoised = model(noise, sigma_hat)
x = denoised + sigmas[i + 1] * (gamma + 1) * noise
else:
_, c_in = get_scalings(sigma_hat)
denoised = model(x * c_in, sigma_hat)
d = (x - denoised) / sigma_hat
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
_, c_in = get_scalings(sigmas[i + 1])
denoised_2 = model(x_2 * c_in, sigmas[i + 1])
d_2 = (x_2 - denoised_2) / sigmas[i + 1]
d_prime = (d + d_2) / 2
x = x + d_prime * dt
return x
class BatchedBrownianTree:
"""
A wrapper around torchsde.BrownianTree that enables batches of entropy.
"""
def __init__(self, x, t0, t1, seed=None, **kwargs):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get('w0', torch.zeros_like(x))
if seed is None:
seed = torch.randint(0, 2**63 - 1, []).item()
self.batched = True
try:
assert len(seed) == x.shape[0]
w0 = w0[0]
except TypeError:
seed = [seed]
self.batched = False
self.trees = [
torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
for s in seed
]
@staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
self.sign * sign)
return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
"""
A noise sampler backed by a torchsde.BrownianTree.
Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
random samples.
sigma_min (float): The low end of the valid interval.
sigma_max (float): The high end of the valid interval.
seed (int or List[int]): The random seed. If a list of seeds is
supplied instead of a single integer, then the noise sampler will
use one BrownianTree per batch item, each with its own seed.
transform (callable): A function that maps sigma to the sampler's
internal timestep.
"""
def __init__(self,
x,
sigma_min,
sigma_max,
seed=None,
transform=lambda x: x):
self.transform = transform
t0 = self.transform(torch.as_tensor(sigma_min))
t1 = self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed)
def __call__(self, sigma, sigma_next):
t0 = self.transform(torch.as_tensor(sigma))
t1 = self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
@torch.no_grad()
def sample_dpmpp_2m_sde(noise,
model,
sigmas,
eta=1.,
s_noise=1.,
solver_type='midpoint',
show_progress=True,
variant_info=None):
"""
DPM-Solver++ (2M) SDE.
"""
assert solver_type in {'heun', 'midpoint'}
x = noise * sigmas[0]
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
sigmas < float('inf')].max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
old_denoised = None
h_last = None
for i in trange(len(sigmas) - 1, disable=not show_progress):
logger.info(f'step: {i}')
if sigmas[i] == float('inf'):
# Euler method
denoised = model(noise, sigmas[i], variant_info=variant_info)
x = denoised + sigmas[i + 1] * noise
else:
_, c_in = get_scalings(sigmas[i])
denoised = model(x * c_in, sigmas[i], variant_info=variant_info)
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
(-h - eta_h).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
(1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * \
(1 / r) * (denoised - old_denoised)
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
if variant_info is not None and variant_info.get('type') == 'variant1':
x_long, x_short = x.chunk(2, dim=0)
x = x_long * (1-variant_info['alpha']) + x_short * variant_info['alpha']
return x
\ No newline at end of file
from .embedder import *
from .unet_v2v import *
# from .unet_v2v_deform import *
\ No newline at end of file
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torchvision.transforms as T
class FrozenOpenCLIPEmbedder(nn.Module):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = ['last', 'penultimate']
def __init__(self,
pretrained='laion2b_s32b_b79k',
arch='ViT-H-14',
device='cuda',
max_length=77,
freeze=True,
layer='penultimate'):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == 'last':
self.layer_idx = 0
elif self.layer == 'penultimate':
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text)
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2)
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2)
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
):
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
\ No newline at end of file
# Adapted from PixArt
#
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
# T5: https://github.com/google-research/text-to-text-transfer-transformer
# --------------------------------------------------------
import html
import re
import ftfy
import torch
from transformers import AutoTokenizer, T5EncoderModel
# from opensora.registry import MODELS
class T5Embedder:
def __init__(
self,
device,
from_pretrained=None,
*,
cache_dir=None,
hf_token=None,
use_text_preprocessing=True,
t5_model_kwargs=None,
torch_dtype=None,
use_offload_folder=None,
model_max_length=120,
local_files_only=False,
):
self.device = torch.device(device)
self.torch_dtype = torch_dtype or torch.bfloat16
self.cache_dir = cache_dir
if t5_model_kwargs is None:
t5_model_kwargs = {
"low_cpu_mem_usage": True,
"torch_dtype": self.torch_dtype,
}
if use_offload_folder is not None:
t5_model_kwargs["offload_folder"] = use_offload_folder
t5_model_kwargs["device_map"] = {
"shared": self.device,
"encoder.embed_tokens": self.device,
"encoder.block.0": self.device,
"encoder.block.1": self.device,
"encoder.block.2": self.device,
"encoder.block.3": self.device,
"encoder.block.4": self.device,
"encoder.block.5": self.device,
"encoder.block.6": self.device,
"encoder.block.7": self.device,
"encoder.block.8": self.device,
"encoder.block.9": self.device,
"encoder.block.10": self.device,
"encoder.block.11": self.device,
"encoder.block.12": "disk",
"encoder.block.13": "disk",
"encoder.block.14": "disk",
"encoder.block.15": "disk",
"encoder.block.16": "disk",
"encoder.block.17": "disk",
"encoder.block.18": "disk",
"encoder.block.19": "disk",
"encoder.block.20": "disk",
"encoder.block.21": "disk",
"encoder.block.22": "disk",
"encoder.block.23": "disk",
"encoder.final_layer_norm": "disk",
"encoder.dropout": "disk",
}
else:
t5_model_kwargs["device_map"] = {
"shared": self.device,
"encoder": self.device,
}
self.use_text_preprocessing = use_text_preprocessing
self.hf_token = hf_token
self.tokenizer = AutoTokenizer.from_pretrained(
from_pretrained,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
self.model = T5EncoderModel.from_pretrained(
from_pretrained,
cache_dir=cache_dir,
local_files_only=local_files_only,
**t5_model_kwargs,
).eval()
self.model_max_length = model_max_length
def get_text_embeddings(self, texts):
text_tokens_and_mask = self.tokenizer(
texts,
max_length=self.model_max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
input_ids = text_tokens_and_mask["input_ids"].to(self.device)
attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
with torch.no_grad():
text_encoder_embs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
)["last_hidden_state"].detach()
return text_encoder_embs, attention_mask
# @MODELS.register_module("t5")
class T5Encoder:
def __init__(
self,
from_pretrained=None,
model_max_length=120,
device="cuda",
dtype=torch.float,
cache_dir=None,
shardformer=False,
local_files_only=False,
):
assert from_pretrained is not None, "Please specify the path to the T5 model"
self.t5 = T5Embedder(
device=device,
torch_dtype=dtype,
from_pretrained=from_pretrained,
cache_dir=cache_dir,
model_max_length=model_max_length,
local_files_only=local_files_only,
)
self.t5.model.to(dtype=dtype)
self.y_embedder = None
self.model_max_length = model_max_length
self.output_dim = self.t5.model.config.d_model
self.dtype = dtype
if shardformer:
self.shardformer_t5()
def shardformer_t5(self):
from colossalai.shardformer import ShardConfig, ShardFormer
from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
from opensora.utils.misc import requires_grad
shard_config = ShardConfig(
tensor_parallel_process_group=None,
pipeline_stage_manager=None,
enable_tensor_parallelism=False,
enable_fused_normalization=False,
enable_flash_attention=False,
enable_jit_fused=True,
enable_sequence_parallelism=False,
enable_sequence_overlap=False,
)
shard_former = ShardFormer(shard_config=shard_config)
optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
self.t5.model = optim_model.to(self.dtype)
# ensure the weights are frozen
requires_grad(self.t5.model, False)
def encode(self, text):
caption_embs, emb_masks = self.t5.get_text_embeddings(text)
caption_embs = caption_embs[:, None]
return dict(y=caption_embs, mask=emb_masks)
def null(self, n):
null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
return null_y
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
BAD_PUNCT_REGEX = re.compile(
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
def clean_caption(caption):
import urllib.parse as ul
from bs4 import BeautifulSoup
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[‘’]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = basic_clean(caption)
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
def text_preprocessing(text, use_text_preprocessing: bool = True):
if use_text_preprocessing:
# The exact text cleaning as was in the training stage:
text = clean_caption(text)
text = clean_caption(text)
return text
else:
return text.lower().strip()
\ No newline at end of file
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers
import xformers.ops
from einops import rearrange
from fairscale.nn.checkpoint import checkpoint_wrapper
from timm.models.vision_transformer import Mlp
USE_TEMPORAL_TRANSFORMER = True
class CaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
super().__init__()
self.y_proj = Mlp(
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
self.uncond_prob = uncond_prob
def token_drop(self, caption, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
def forward(self, caption, train, force_drop_ids=None):
if train:
assert caption.shape[2:] == self.y_embedding.shape
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
return caption
class DropPath(nn.Module):
r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
"""
def __init__(self, p):
super(DropPath, self).__init__()
self.p = p
def forward(self, *args, zero=None, keep=None):
if not self.training:
return args[0] if len(args) == 1 else args
# params
x = args[0]
b = x.size(0)
n = (torch.rand(b) < self.p).sum()
# non-zero and non-keep mask
mask = x.new_ones(b, dtype=torch.bool)
if keep is not None:
mask[keep] = False
if zero is not None:
mask[zero] = False
# drop-path index
index = torch.where(mask)[0]
index = index[torch.randperm(len(index))[:n]]
if zero is not None:
index = torch.cat([index, torch.where(zero)[0]], dim=0)
# drop-path multiplier
multiplier = x.new_ones(b)
multiplier[index] = 0.0
output = tuple(u * self.broadcast(multiplier, u) for u in args)
return output[0] if len(args) == 1 else output
def broadcast(self, src, dst):
assert src.size(0) == dst.size(0)
shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
return src.view(shape)
def sinusoidal_embedding(timesteps, dim):
# check input
half = dim // 2
timesteps = timesteps.float()
# compute sinusoidal embedding
sinusoid = torch.outer(
timesteps, torch.pow(10000,
-torch.arange(half).to(timesteps).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if dim % 2 != 0:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
return x
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def prob_mask_like(shape, prob, device):
if prob == 1:
return torch.ones(shape, device=device, dtype=torch.bool)
elif prob == 0:
return torch.zeros(shape, device=device, dtype=torch.bool)
else:
mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
# aviod mask all, which will cause find_unused_parameters error
if mask.all():
mask[0] = False
return mask
class MemoryEfficientCrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
max_bs=16384,
dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.max_bs = max_bs
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of.
if q.shape[0] > self.max_bs:
q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
k_list = torch.chunk(k, k.shape[0] // self.max_bs, dim=0)
v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0)
out_list = []
for q_1, k_1, v_1 in zip(q_list, k_list, v_list):
out = xformers.ops.memory_efficient_attention(
q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
out_list.append(out)
out = torch.cat(out_list, dim=0)
else:
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
return self.to_out(out)
class RelativePositionBias(nn.Module):
def __init__(self, heads=8, num_buckets=32, max_distance=128):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position,
num_buckets=32,
max_distance=128):
ret = 0
n = -relative_position
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact)
/ math.log(max_distance / max_exact) * # noqa
(num_buckets - max_exact)).long()
val_if_large = torch.min(
val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, n, device):
q_pos = torch.arange(n, dtype=torch.long, device=device)
k_pos = torch.arange(n, dtype=torch.long, device=device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(
rel_pos,
num_buckets=self.num_buckets,
max_distance=self.max_distance)
values = self.relative_attention_bias(rp_bucket)
return rearrange(values, 'i j h -> h i j')
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
disable_self_attn=False,
use_linear=False,
use_checkpoint=True,
is_ctrl=False):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
local_type='space',
is_ctrl=is_ctrl) for d in range(depth)
])
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
_, _, h, w = x.shape
# print('x shape:', x.shape) # [64, 320, 90, 160]
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i], h=h, w=w)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
class CrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION == 'fp32':
with torch.autocast(enabled=False, device_type='cuda'):
q, k = q.float(), k.float()
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
class SpatialAttention(nn.Module):
def __init__(self):
super(SpatialAttention, self).__init__()
self.conv1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, padding=7 // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
max_out, _ = torch.max(x, dim=1, keepdim=True)
avg_out = torch.mean(x, dim=1, keepdim=True)
weight = torch.cat([max_out, avg_out], dim=1)
weight = self.conv1(weight)
out = self.sigmoid(weight) * x
return out
class TemporalLocalAttention(nn.Module):
def __init__(self):
super(TemporalLocalAttention, self).__init__()
self.conv1 = nn.Linear(in_features=2, out_features=1, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
max_out, _ = torch.max(x, dim=-1, keepdim=True)
avg_out = torch.mean(x, dim=-1, keepdim=True)
weight = torch.cat([max_out, avg_out], dim=-1)
weight = self.conv1(weight)
out = self.sigmoid(weight) * x
return out
class BasicTransformerBlock(nn.Module):
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
local_type=None,
is_ctrl=False):
super().__init__()
self.local_type = local_type
self.is_ctrl = is_ctrl
attn_cls = MemoryEfficientCrossAttention
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls( # self-attn
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
attn_cls2 = MemoryEfficientCrossAttention
self.attn2 = attn_cls2(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
if self.local_type == 'space' and self.is_ctrl:
self.local1 = SpatialAttention()
if self.local_type == 'temp' and self.is_ctrl:
self.local1 = TemporalLocalAttention()
self.local2 = TemporalLocalAttention()
def forward_(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def forward(self, x, context=None, h=None, w=None):
if self.local_type == 'space' and self.is_ctrl: # [b*t,(hw), c]
x_local = rearrange(x, 'b (h w) c -> b c h w', h=h)
x_local = self.local1(x_local)
x_local = rearrange(x_local, 'b c h w -> b (h w) c')
x = self.attn1(
self.norm1(x_local),
context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x # cross attention or self-attention
x = self.ff(self.norm3(x)) + x
if self.local_type == 'temp' and self.is_ctrl:
x_local = self.local1(x)
x = self.attn1(
self.norm1(x_local),
context=context if self.disable_self_attn else None) + x
x_local = self.local2(x)
x = self.attn2(self.norm2(x_local), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = nn.Conv2d(
self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode='nearest')
else:
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = x[..., 1:-1, :]
if self.use_conv:
x = self.conv(x)
return x
class ResBlock(nn.Module):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
up=False,
down=False,
use_temporal_conv=True,
use_image_dataset=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_scale_shift_norm = use_scale_shift_norm
self.use_temporal_conv = use_temporal_conv
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels),
nn.SiLU(),
nn.Conv2d(channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(
emb_channels,
2 * self.out_channels
if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
if self.use_temporal_conv:
self.temopral_conv = TemporalConvBlock_v2(
self.out_channels,
self.out_channels,
dropout=0.1,
use_image_dataset=use_image_dataset)
def forward(self, x, emb, batch_size, variant_info=None):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return self._forward(x, emb, batch_size, variant_info)
def _forward(self, x, emb, batch_size, variant_info):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
h = self.skip_connection(x) + h
if self.use_temporal_conv:
h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
h = self.temopral_conv(h, variant_info=variant_info)
h = rearrange(h, 'b c f h w -> (b f) c h w')
return h
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=(2, 1)):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = nn.Conv2d(
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class Resample(nn.Module):
def __init__(self, in_dim, out_dim, mode):
assert mode in ['none', 'upsample', 'downsample']
super(Resample, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.mode = mode
def forward(self, x, reference=None):
if self.mode == 'upsample':
assert reference is not None
x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
elif self.mode == 'downsample':
x = F.adaptive_avg_pool2d(
x, output_size=tuple(u // 2 for u in x.shape[-2:]))
return x
class ResidualBlock(nn.Module):
def __init__(self,
in_dim,
embed_dim,
out_dim,
use_scale_shift_norm=True,
mode='none',
dropout=0.0):
super(ResidualBlock, self).__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.use_scale_shift_norm = use_scale_shift_norm
self.mode = mode
# layers
self.layer1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(),
nn.Conv2d(in_dim, out_dim, 3, padding=1))
self.resample = Resample(in_dim, in_dim, mode)
self.embedding = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim,
out_dim * 2 if use_scale_shift_norm else out_dim))
self.layer2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, 3, padding=1))
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
in_dim, out_dim, 1)
# zero out the last layer params
nn.init.zeros_(self.layer2[-1].weight)
def forward(self, x, e, reference=None):
identity = self.resample(x, reference)
x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
if self.use_scale_shift_norm:
scale, shift = e.chunk(2, dim=1)
x = self.layer2[0](x) * (1 + scale) + shift
x = self.layer2[1:](x)
else:
x = x + e
x = self.layer2(x)
x = x + self.shortcut(identity)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
# consider head_dim first, then num_heads
num_heads = dim // head_dim if head_dim else num_heads
head_dim = dim // num_heads
assert num_heads * head_dim == dim
super(AttentionBlock, self).__init__()
self.dim = dim
self.context_dim = context_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = math.pow(head_dim, -0.25)
# layers
self.norm = nn.GroupNorm(32, dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
if context_dim is not None:
self.context_kv = nn.Linear(context_dim, dim * 2)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x, context=None):
r"""x: [B, C, H, W].
context: [B, L, C] or None.
"""
identity = x
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
x = self.norm(x)
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
if context is not None:
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
d).permute(0, 2, 3,
1).chunk(
2, dim=1)
k = torch.cat([ck, k], dim=-1)
v = torch.cat([cv, v], dim=-1)
# compute attention
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
attn = F.softmax(attn, dim=-1)
# gather context
x = torch.matmul(v, attn.transpose(-1, -2))
x = x.reshape(b, c, h, w)
# output
x = self.proj(x)
return x + identity
class TemporalAttentionBlock(nn.Module):
def __init__(self,
dim,
heads=4,
dim_head=32,
rotary_emb=None,
use_image_dataset=False,
use_sim_mask=False):
super().__init__()
# consider num_heads first, as pos_bias needs fixed num_heads
dim_head = dim // heads
assert heads * dim_head == dim
self.use_image_dataset = use_image_dataset
self.use_sim_mask = use_sim_mask
self.scale = dim_head**-0.5
self.heads = heads
hidden_dim = dim_head * heads
self.norm = nn.GroupNorm(32, dim)
self.rotary_emb = rotary_emb
self.to_qkv = nn.Linear(dim, hidden_dim * 3)
self.to_out = nn.Linear(hidden_dim, dim)
def forward(self,
x,
pos_bias=None,
focus_present_mask=None,
video_mask=None):
identity = x
n, height, device = x.shape[2], x.shape[-2], x.device
x = self.norm(x)
x = rearrange(x, 'b c f h w -> b (h w) f c')
qkv = self.to_qkv(x).chunk(3, dim=-1)
if exists(focus_present_mask) and focus_present_mask.all():
# if all batch samples are focusing on present
# it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
values = qkv[-1]
out = self.to_out(values)
out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
return out + identity
# split out heads
q = rearrange(qkv[0], '... n (h d) -> ... h n d', h=self.heads)
k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads)
v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads)
# scale
q = q * self.scale
# rotate positions into queries and keys for time attention
if exists(self.rotary_emb):
q = self.rotary_emb.rotate_queries_or_keys(q)
k = self.rotary_emb.rotate_queries_or_keys(k)
# similarity
# shape [b (hw) h n n], n=f
sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
# relative positional bias
if exists(pos_bias):
sim = sim + pos_bias
if (focus_present_mask is None and video_mask is not None):
# video_mask: [B, n]
mask = video_mask[:, None, :] * video_mask[:, :, None]
mask = mask.unsqueeze(1).unsqueeze(1)
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
elif exists(focus_present_mask) and not (~focus_present_mask).all():
attend_all_mask = torch.ones((n, n),
device=device,
dtype=torch.bool)
attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)
mask = torch.where(
rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
)
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
if self.use_sim_mask:
sim_mask = torch.tril(
torch.ones((n, n), device=device, dtype=torch.bool),
diagonal=0)
sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)
# numerical stability
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
# aggregate values
out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
out = rearrange(out, '... h n d -> ... n (h d)')
out = self.to_out(out)
out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
if self.use_image_dataset:
out = identity + 0 * out
else:
out = identity + out
return out
class TemporalTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
disable_self_attn=False,
use_linear=False,
use_checkpoint=True,
only_self_att=True,
multiply_zero=False,
is_ctrl=False):
super().__init__()
self.multiply_zero = multiply_zero
self.only_self_att = only_self_att
self.use_adaptor = False
if self.only_self_att:
context_dim = None
if not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
if not use_linear:
self.proj_in = nn.Conv1d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
if self.use_adaptor:
self.adaptor_in = nn.Linear(frames, frames)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
checkpoint=use_checkpoint,
local_type='temp',
is_ctrl=is_ctrl) for d in range(depth)
])
if not use_linear:
self.proj_out = zero_module(
nn.Conv1d(
inner_dim, in_channels, kernel_size=1, stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
if self.use_adaptor:
self.adaptor_out = nn.Linear(frames, frames)
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if self.only_self_att:
context = None
if not isinstance(context, list):
context = [context]
b, _, _, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
x = self.proj_in(x)
if self.use_linear:
x = rearrange(
x, 'b c f h w -> (b h w) f c').contiguous()
x = self.proj_in(x)
x = rearrange(
x, 'bhw f c -> bhw c f').contiguous()
# print('x shape:', x.shape) # [28800, 512, 32]
if self.only_self_att: # no cross-attention
x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
for i, block in enumerate(self.transformer_blocks):
x = block(x, h=h, w=w)
# print('x shape:', x.shape) # [43200, 32, 512]
x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
else:
x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
for i, block in enumerate(self.transformer_blocks):
context[i] = rearrange(
context[i], '(b f) l con -> b f l con',
f=self.frames).contiguous()
# calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
for j in range(b):
context_i_j = repeat(
context[i][j],
'f l con -> (f r) l con',
r=(h * w) // self.frames,
f=self.frames).contiguous()
x[j] = block(x[j], context=context_i_j)
if self.use_linear:
x = rearrange(x, 'b hw f c -> (b hw) f c').contiguous()
x = self.proj_out(x)
x = rearrange(
x, '(b h w) f c -> b c f h w', b=b, h=h, w=w).contiguous()
if not self.use_linear:
# print('x shape:', x.shape) # [2, 21600, 32, 512]
x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
x = self.proj_out(x)
x = rearrange(
x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
if self.multiply_zero:
x = 0.0 * x + x_in
else:
x = x + x_in
return x
class TemporalAttentionMultiBlock(nn.Module):
def __init__(
self,
dim,
heads=4,
dim_head=32,
rotary_emb=None,
use_image_dataset=False,
use_sim_mask=False,
temporal_attn_times=1,
):
super().__init__()
self.att_layers = nn.ModuleList([
TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
use_image_dataset, use_sim_mask)
for _ in range(temporal_attn_times)
])
def forward(self,
x,
pos_bias=None,
focus_present_mask=None,
video_mask=None):
for layer in self.att_layers:
x = layer(x, pos_bias, focus_present_mask, video_mask)
return x
class InitTemporalConvBlock(nn.Module):
def __init__(self,
in_dim,
out_dim=None,
dropout=0.0,
use_image_dataset=False):
super(InitTemporalConvBlock, self).__init__()
if out_dim is None:
out_dim = in_dim
self.in_dim = in_dim
self.out_dim = out_dim
self.use_image_dataset = use_image_dataset
# conv layers
self.conv = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
# zero out the last layer params,so the conv block is identity
nn.init.zeros_(self.conv[-1].weight)
nn.init.zeros_(self.conv[-1].bias)
def forward(self, x):
identity = x
x = self.conv(x)
if self.use_image_dataset:
x = identity + 0 * x
else:
x = identity + x
return x
class TemporalConvBlock(nn.Module):
def __init__(self,
in_dim,
out_dim=None,
dropout=0.0,
use_image_dataset=False):
super(TemporalConvBlock, self).__init__()
if out_dim is None:
out_dim = in_dim
self.in_dim = in_dim
self.out_dim = out_dim
self.use_image_dataset = use_image_dataset
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
# zero out the last layer params,so the conv block is identity
nn.init.zeros_(self.conv2[-1].weight)
nn.init.zeros_(self.conv2[-1].bias)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
if self.use_image_dataset:
x = identity + 0 * x
else:
x = identity + x
return x
class TemporalConvBlock_v2(nn.Module):
def __init__(self,
in_dim,
out_dim=None,
dropout=0.0,
use_image_dataset=False):
super(TemporalConvBlock_v2, self).__init__()
if out_dim is None:
out_dim = in_dim
self.in_dim = in_dim
self.out_dim = out_dim
self.use_image_dataset = use_image_dataset
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
self.conv3 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
self.conv4 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
# zero out the last layer params,so the conv block is identity
nn.init.zeros_(self.conv4[-1].weight)
nn.init.zeros_(self.conv4[-1].bias)
def forward(self, x, variant_info=None):
if variant_info is not None and variant_info.get('type') == 'variant2':
# print(x.shape) # torch.Size([1, 320, 32, 90, 160])
_, _, f, _, _ = x.shape
assert f % 4 == 0, "f must be divisible by 4"
x_short = rearrange(x, "b c (n s) h w -> (n b) c s h w", n=4)
x_short = self.conv1(x_short)
x_short = self.conv2(x_short)
x_short = self.conv3(x_short)
x_short = self.conv4(x_short)
x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = x * (1-variant_info['alpha']) + x_short * variant_info['alpha']
elif variant_info is not None and variant_info.get('type') == 'variant1':
identity = x
x_long, x_short = x.chunk(2, dim=0)
x_short = rearrange(x_short, "b c (n s) h w -> (n b) c s h w", n=4)
x_short = self.conv1(x_short)
x_short = self.conv2(x_short)
x_short = self.conv3(x_short)
x_short = self.conv4(x_short)
x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
x_long = self.conv1(x_long)
x_long = self.conv2(x_long)
x_long = self.conv3(x_long)
x_long = self.conv4(x_long)
x = torch.cat([x_long, x_short], dim=0)
elif variant_info is None:
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
if self.use_image_dataset:
x = identity + 0.0 * x
else:
x = identity + x
return x
class Vid2VidSDUNet(nn.Module):
def __init__(self,
in_dim=4,
dim=320,
y_dim=1024,
context_dim=1024,
out_dim=4,
dim_mult=[1, 2, 4, 4],
num_heads=8,
head_dim=64,
num_res_blocks=2,
attn_scales=[1 / 1, 1 / 2, 1 / 4],
use_scale_shift_norm=True,
dropout=0.1,
temporal_attn_times=1,
temporal_attention=True,
use_checkpoint=True,
use_image_dataset=False,
use_fps_condition=False,
use_sim_mask=False,
training=False,
inpainting=True):
embed_dim = dim * 4
num_heads = num_heads if num_heads else dim // 32
super(Vid2VidSDUNet, self).__init__()
self.in_dim = in_dim
self.dim = dim
self.y_dim = y_dim
self.context_dim = context_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
# for temporal attention
self.num_heads = num_heads
# for spatial attention
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.use_scale_shift_norm = use_scale_shift_norm
self.temporal_attn_times = temporal_attn_times
self.temporal_attention = temporal_attention
self.use_checkpoint = use_checkpoint
self.use_image_dataset = use_image_dataset
self.use_fps_condition = use_fps_condition
self.use_sim_mask = use_sim_mask
self.training = training
self.inpainting = inpainting
use_linear_in_temporal = False
transformer_depth = 1
disabled_sa = False
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
# embeddings
self.time_embed = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
if self.use_fps_condition:
self.fps_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
nn.init.zeros_(self.fps_embedding[-1].weight)
nn.init.zeros_(self.fps_embedding[-1].bias)
# encoder
self.input_blocks = nn.ModuleList()
init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
# need an initial temporal attention?
if temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
init_block.append(
TemporalTransformer(
dim,
num_heads,
head_dim,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
is_ctrl=True
))
else:
init_block.append(
TemporalAttentionMultiBlock(
dim,
num_heads,
head_dim,
rotary_emb=self.rotary_emb,
temporal_attn_times=temporal_attn_times,
use_image_dataset=use_image_dataset))
self.input_blocks.append(init_block)
shortcut_dims.append(dim)
for i, (in_dim,
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
block = nn.ModuleList([
ResBlock(
in_dim,
embed_dim,
dropout,
out_channels=out_dim,
use_scale_shift_norm=False,
use_image_dataset=use_image_dataset,
)
])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=1,
context_dim=self.context_dim,
disable_self_attn=False,
use_linear=True,
is_ctrl=True
))
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(
TemporalTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
is_ctrl=True
))
else:
block.append(
TemporalAttentionMultiBlock(
out_dim,
num_heads,
head_dim,
rotary_emb=self.rotary_emb,
use_image_dataset=use_image_dataset,
use_sim_mask=use_sim_mask,
temporal_attn_times=temporal_attn_times))
in_dim = out_dim
self.input_blocks.append(block)
shortcut_dims.append(out_dim)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
downsample = Downsample(
out_dim, True, dims=2, out_channels=out_dim)
shortcut_dims.append(out_dim)
scale /= 2.0
self.input_blocks.append(downsample)
self.middle_block = nn.ModuleList([
ResBlock(
out_dim,
embed_dim,
dropout,
use_scale_shift_norm=False,
use_image_dataset=use_image_dataset,
),
SpatialTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=1,
context_dim=self.context_dim,
disable_self_attn=False,
use_linear=True,
is_ctrl=True
)
])
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
self.middle_block.append(
TemporalTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
is_ctrl=True
))
else:
self.middle_block.append(
TemporalAttentionMultiBlock(
out_dim,
num_heads,
head_dim,
rotary_emb=self.rotary_emb,
use_image_dataset=use_image_dataset,
use_sim_mask=use_sim_mask,
temporal_attn_times=temporal_attn_times))
self.middle_block.append(
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
# decoder
self.output_blocks = nn.ModuleList()
for i, (in_dim,
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
block = nn.ModuleList([
ResBlock(
in_dim + shortcut_dims.pop(),
embed_dim,
dropout,
out_dim,
use_scale_shift_norm=False,
use_image_dataset=use_image_dataset,
)
])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=1,
context_dim=1024,
disable_self_attn=False,
use_linear=True,
is_ctrl=True))
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(
TemporalTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
is_ctrl=True))
else:
block.append(
TemporalAttentionMultiBlock(
out_dim,
num_heads,
head_dim,
rotary_emb=self.rotary_emb,
use_image_dataset=use_image_dataset,
use_sim_mask=use_sim_mask,
temporal_attn_times=temporal_attn_times))
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
upsample = Upsample(
out_dim, True, dims=2.0, out_channels=out_dim)
scale *= 2.0
block.append(upsample)
self.output_blocks.append(block)
# head
self.out = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
# zero out the last layer params
nn.init.zeros_(self.out[-1].weight)
def forward(self,
x,
t,
y,
x_lr=None,
fps=None,
video_mask=None,
focus_present_mask=None,
prob_focus_present=0.,
mask_last_frame_num=0):
batch, c, f, h, w = x.shape
device = x.device
self.batch = batch
# image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
if mask_last_frame_num > 0:
focus_present_mask = None
video_mask[-mask_last_frame_num:] = False
else:
focus_present_mask = default(
focus_present_mask, lambda: prob_mask_like(
(batch, ), prob_focus_present, device=device))
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
time_rel_pos_bias = self.time_rel_pos_bias(
x.shape[2], device=x.device)
else:
time_rel_pos_bias = None
# embeddings
e = self.time_embed(sinusoidal_embedding(t, self.dim))
context = y
# repeat f times for spatial e and context
e = e.repeat_interleave(repeats=f, dim=0)
context = context.repeat_interleave(repeats=f, dim=0)
# always in shape (b f) c h w, except for temporal layer
x = rearrange(x, 'b c f h w -> (b f) c h w')
# encoder
xs = []
for ind, block in enumerate(self.input_blocks):
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
focus_present_mask, video_mask)
xs.append(x)
# middle
for block in self.middle_block:
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
focus_present_mask, video_mask)
# decoder
for block in self.output_blocks:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(
block,
x,
e,
context,
time_rel_pos_bias,
focus_present_mask,
video_mask,
reference=xs[-1] if len(xs) > 0 else None)
# head
x = self.out(x)
# reshape back to (b c f h w)
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
return x
def _forward_single(self,
module,
x,
e,
context,
time_rel_pos_bias,
focus_present_mask,
video_mask,
reference=None):
if isinstance(module, ResidualBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, reference)
elif isinstance(module, ResBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, self.batch)
elif isinstance(module, SpatialTransformer):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, TemporalTransformer):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, CrossAttention):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, MemoryEfficientCrossAttention):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, BasicTransformerBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, FeedForward):
x = module(x, context)
elif isinstance(module, Upsample):
x = module(x)
elif isinstance(module, Downsample):
x = module(x)
elif isinstance(module, Resample):
x = module(x, reference)
elif isinstance(module, TemporalAttentionBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalAttentionMultiBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, InitTemporalConvBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalConvBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context,
time_rel_pos_bias, focus_present_mask,
video_mask, reference)
else:
x = module(x)
return x
class ControlledV2VUNet(Vid2VidSDUNet):
def __init__(self):
super(ControlledV2VUNet, self).__init__()
self.VideoControlNet = VideoControlNet()
def forward(self,
x,
t,
y,
hint=None,
variant_info=None,
hint_chunk=None,
t_hint=None,
s_cond=None,
mask_cond=None,
x_lr=None,
fps=None,
mask=None,
video_mask=None,
focus_present_mask=None,
prob_focus_present=0.,
mask_last_frame_num=0,
):
batch, _, f, _, _= x.shape
device = x.device
self.batch = batch
# Process text (new added for t5 encoder)
# y = self.VideoControlNet.y_embedder(y, self.training).squeeze(1) # [1, 1, 120, 4096] -> [B, 1, 120, 1024].squeeze(1) -> [B, 120, 1024]
if hint_chunk is not None:
hint = hint_chunk
control = self.VideoControlNet(x, t, y, hint=hint, t_hint=t_hint, \
mask_cond=mask_cond, s_cond=s_cond, \
variant_info=variant_info)
# image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
if mask_last_frame_num > 0:
focus_present_mask = None
video_mask[-mask_last_frame_num:] = False
else:
focus_present_mask = default(
focus_present_mask, lambda: prob_mask_like(
(batch, ), prob_focus_present, device=device))
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
time_rel_pos_bias = self.time_rel_pos_bias(
x.shape[2], device=x.device)
else:
time_rel_pos_bias = None
e = self.time_embed(sinusoidal_embedding(t, self.dim))
e = e.repeat_interleave(repeats=f, dim=0)
# context = y
context = y.repeat_interleave(repeats=f, dim=0)
# always in shape (b f) c h w, except for temporal layer
x = rearrange(x, 'b c f h w -> (b f) c h w')
# encoder
xs = []
for block in self.input_blocks:
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
focus_present_mask, video_mask, variant_info=variant_info)
xs.append(x)
# middle
for block in self.middle_block:
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
focus_present_mask, video_mask, variant_info=variant_info)
if control is not None:
x = control.pop() + x
# decoder
for block in self.output_blocks:
if control is None:
x = torch.cat([x, xs.pop()], dim=1)
else:
x = torch.cat([x, xs.pop() + control.pop()], dim=1)
x = self._forward_single(
block,
x,
e,
context,
time_rel_pos_bias,
focus_present_mask,
video_mask,
reference=xs[-1] if len(xs) > 0 else None,
variant_info=variant_info)
# head
x = self.out(x)
# reshape back to (b c f h w)
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
return x
def _forward_single(self,
module,
x,
e,
context,
time_rel_pos_bias,
focus_present_mask,
video_mask,
reference=None,
variant_info=None):
variant_info = None # For Debug
if isinstance(module, ResidualBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, reference)
elif isinstance(module, ResBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, self.batch, variant_info)
elif isinstance(module, SpatialTransformer):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, TemporalTransformer):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, CrossAttention):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, MemoryEfficientCrossAttention):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, BasicTransformerBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, FeedForward):
x = module(x, context)
elif isinstance(module, Upsample):
x = module(x)
elif isinstance(module, Downsample):
x = module(x)
elif isinstance(module, Resample):
x = module(x, reference)
elif isinstance(module, TemporalAttentionBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalAttentionMultiBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, InitTemporalConvBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalConvBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context,
time_rel_pos_bias, focus_present_mask,
video_mask, reference, variant_info)
else:
x = module(x)
return x
class VideoControlNet(nn.Module):
def __init__(self,
in_dim=4,
dim=320,
y_dim=1024,
context_dim=1024,
out_dim=4,
dim_mult=[1, 2, 4, 4],
num_heads=8,
head_dim=64,
num_res_blocks=2,
attn_scales=[1 / 1, 1 / 2, 1 / 4],
use_scale_shift_norm=True,
dropout=0.1,
temporal_attn_times=1,
temporal_attention=True,
use_checkpoint=True,
use_image_dataset=False,
use_fps_condition=False,
use_sim_mask=False,
training=False,
inpainting=True):
embed_dim = dim * 4
num_heads = num_heads if num_heads else dim // 32
super(VideoControlNet, self).__init__()
self.in_dim = in_dim
self.dim = dim
self.y_dim = y_dim
self.context_dim = context_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
# for temporal attention
self.num_heads = num_heads
# for spatial attention
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.use_scale_shift_norm = use_scale_shift_norm
self.temporal_attn_times = temporal_attn_times
self.temporal_attention = temporal_attention
self.use_checkpoint = use_checkpoint
self.use_image_dataset = use_image_dataset
self.use_fps_condition = use_fps_condition
self.use_sim_mask = use_sim_mask
self.training = training
self.inpainting = inpainting
use_linear_in_temporal = False
transformer_depth = 1
disabled_sa = False
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
# CaptionEmbedder (new add)
# approx_gelu = lambda: nn.GELU(approximate="tanh")
# self.y_embedder = CaptionEmbedder(
# in_channels=4096,
# hidden_size=1024,
# uncond_prob=0.1,
# act_layer=approx_gelu,
# token_num=120,
# )
# embeddings
self.time_embed = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
# self.hint_time_zero_linear = zero_module(nn.Linear(embed_dim, embed_dim))
# scale prompt
# self.scale_cond = nn.Sequential(
# nn.Linear(dim, embed_dim), nn.SiLU(),
# zero_module(nn.Linear(embed_dim, embed_dim)))
if self.use_fps_condition:
self.fps_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
nn.init.zeros_(self.fps_embedding[-1].weight)
nn.init.zeros_(self.fps_embedding[-1].bias)
# encoder
self.input_blocks = nn.ModuleList()
init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
# need an initial temporal attention?
if temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
init_block.append(
TemporalTransformer(
dim,
num_heads,
head_dim,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
is_ctrl=True,))
else:
init_block.append(
TemporalAttentionMultiBlock(
dim,
num_heads,
head_dim,
rotary_emb=self.rotary_emb,
temporal_attn_times=temporal_attn_times,
use_image_dataset=use_image_dataset))
self.input_blocks.append(init_block)
self.zero_convs = nn.ModuleList([self.make_zero_conv(dim)])
shortcut_dims.append(dim)
for i, (in_dim,
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
block = nn.ModuleList([
ResBlock(
in_dim,
embed_dim,
dropout,
out_channels=out_dim,
use_scale_shift_norm=False,
use_image_dataset=use_image_dataset,
)
])
if scale in attn_scales:
block.append(
SpatialTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=1,
context_dim=self.context_dim,
disable_self_attn=False,
use_linear=True,
is_ctrl=True))
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
block.append(
TemporalTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
is_ctrl=True,))
else:
block.append(
TemporalAttentionMultiBlock(
out_dim,
num_heads,
head_dim,
rotary_emb=self.rotary_emb,
use_image_dataset=use_image_dataset,
use_sim_mask=use_sim_mask,
temporal_attn_times=temporal_attn_times))
in_dim = out_dim
self.input_blocks.append(block)
self.zero_convs.append(self.make_zero_conv(out_dim))
shortcut_dims.append(out_dim)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
downsample = Downsample(
out_dim, True, dims=2, out_channels=out_dim)
shortcut_dims.append(out_dim)
scale /= 2.0
self.input_blocks.append(downsample)
self.zero_convs.append(self.make_zero_conv(out_dim))
self.middle_block = nn.ModuleList([
ResBlock(
out_dim,
embed_dim,
dropout,
use_scale_shift_norm=False,
use_image_dataset=use_image_dataset,
),
SpatialTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=1,
context_dim=self.context_dim,
disable_self_attn=False,
use_linear=True,
is_ctrl=True)
])
if self.temporal_attention:
if USE_TEMPORAL_TRANSFORMER:
self.middle_block.append(
TemporalTransformer(
out_dim,
out_dim // head_dim,
head_dim,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_temporal,
multiply_zero=use_image_dataset,
is_ctrl=True,
))
else:
self.middle_block.append(
TemporalAttentionMultiBlock(
out_dim,
num_heads,
head_dim,
rotary_emb=self.rotary_emb,
use_image_dataset=use_image_dataset,
use_sim_mask=use_sim_mask,
temporal_attn_times=temporal_attn_times))
self.middle_block.append(
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
self.middle_block_out = self.make_zero_conv(embed_dim)
'''
add prompt
'''
add_dim = 320
self.add_dim = add_dim
self.input_hint_block = zero_module(nn.Conv2d(4, add_dim, 3, padding=1))
def make_zero_conv(self, in_channels, out_channels=None):
out_channels = in_channels if out_channels is None else out_channels
return TimestepEmbedSequential(zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)))
def forward(self,
x,
t,
y,
s_cond=None,
hint=None,
variant_info=None,
t_hint=None,
mask_cond=None,
fps=None,
video_mask=None,
focus_present_mask=None,
prob_focus_present=0.,
mask_last_frame_num=0):
batch, _, f, _, _ = x.shape
device = x.device
self.batch = batch
# image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
if mask_last_frame_num > 0:
focus_present_mask = None
video_mask[-mask_last_frame_num:] = False
else:
focus_present_mask = default(
focus_present_mask, lambda: prob_mask_like(
(batch, ), prob_focus_present, device=device))
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
time_rel_pos_bias = self.time_rel_pos_bias(
x.shape[2], device=x.device)
else:
time_rel_pos_bias = None
if hint is not None:
# add = x.new_zeros(batch, self.add_dim, f, h, w)
hint = rearrange(hint, 'b c f h w -> (b f) c h w')
hint = self.input_hint_block(hint)
# hint = rearrange(hint, '(b f) c h w -> b c f h w', b = batch)
e = self.time_embed(sinusoidal_embedding(t, self.dim))
e = e.repeat_interleave(repeats=f, dim=0)
context = y.repeat_interleave(repeats=f, dim=0)
# always in shape (b f) c h w, except for temporal layer
x = rearrange(x, 'b c f h w -> (b f) c h w')
# print('before x shape:', x.shape) [64, 320, 90, 160]
# print('hint shape:', hint.shape) [32, 320, 90, 160]
# encoder
xs = []
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if hint is not None:
for block in module:
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
focus_present_mask, video_mask, variant_info=variant_info)
if not isinstance(block, TemporalTransformer):
if hint is not None:
x += hint
hint = None
else:
x = self._forward_single(module, x, e, context, time_rel_pos_bias,
focus_present_mask, video_mask, variant_info=variant_info)
xs.append(zero_conv(x, e, context))
# middle
for block in self.middle_block:
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
focus_present_mask, video_mask, variant_info=variant_info)
xs.append(self.middle_block_out(x, e, context))
return xs
def _forward_single(self,
module,
x,
e,
context,
time_rel_pos_bias,
focus_present_mask,
video_mask,
reference=None,
variant_info=None,):
# variant_info = None # For Debug
if isinstance(module, ResidualBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, reference)
elif isinstance(module, ResBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = x.contiguous()
x = module(x, e, self.batch, variant_info)
elif isinstance(module, SpatialTransformer):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, TemporalTransformer):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
# print("x shape:", x.shape) # [2, 320, 32, 90, 160]
x = module(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, CrossAttention):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, MemoryEfficientCrossAttention):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, BasicTransformerBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = module(x, context)
elif isinstance(module, FeedForward):
x = module(x, context)
elif isinstance(module, Upsample):
x = module(x)
elif isinstance(module, Downsample):
x = module(x)
elif isinstance(module, Resample):
x = module(x, reference)
elif isinstance(module, TemporalAttentionBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalAttentionMultiBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, InitTemporalConvBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, TemporalConvBlock):
module = checkpoint_wrapper(
module) if self.use_checkpoint else module
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
x = module(x)
x = rearrange(x, 'b c f h w -> (b f) c h w')
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context,
time_rel_pos_bias, focus_present_mask,
video_mask, reference, variant_info)
else:
x = module(x)
return x
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
else:
x = layer(x)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment